Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-19 02:54:28 -05:00)

Compare commits: fix/spinne...add-llm-ma (84 commits)
Commits in this comparison (SHA1):

b54022bded, 987712dac1, e01526cf52, 1704812f50, 29f95e5b61, 266526f08c, 26490e32d8, d6bf54281b,
a7835056c9, cf3390d192, d8007f74e9, 4d341c55c5, 01ef7e1925, 5baf1a0f60, 9fc5d465da, c797f4e1f2,
05033610bb, 76f3a89be8, df7bb57c83, b11d46d246, 8e6bc5eb48, 8b2b0c853a, ffb86cced4, fea46a6d28,
f2f779e54f, dda9a9b010, c1d3604682, dfbfbdf696, 994ebc2cf8, 2245d115d3, 5238b1b71c, 4fb86b2738,
e10128e9f0, b205d5863e, 6da2dee62f, 324ebc1e06, ce2ebee838, 0597573b6c, 9496b33a1c, 8e3aabd558,
fbef81c0c9, 226d2ef4a0, 42f8a26ee1, 8d021fe76c, cb10907bf6, 54084fe597, 8f5d851908, 358a21c6fc,
336fc43b24, cfb1613877, 386eea741c, e5c6809d9c, 963b8090cc, eab93aba2b, 47a70cdbd0, 69c9136060,
6ed8bb4f14, 6cf28e58d3, 632ef24408, 6dc767aafa, 23e37fd163, 63869fe710, 90ae75d475, 9b6dc3be12,
9b8b6252c5, 0d321323f5, 3ee3ea8f02, 7a842d35ae, 07e8568f57, 13a0caa5d8, 664523a721, 33b103d09b,
2e3fc99caa, 52c7b223df, 24d86fde30, df7be39724, 8c7b1af409, b6e2f05b63, 7435739053, a97fdba554,
ec705bbbcf, 7fe6b576ae, dfc42003a1, 6bbeb22943
.github/workflows/platform-backend-ci.yml (vendored, 9 changes)

@@ -41,18 +41,13 @@ jobs:
       ports:
         - 6379:6379
     rabbitmq:
-      image: rabbitmq:4.1.4
+      image: rabbitmq:3.12-management
       ports:
         - 5672:5672
+        - 15672:15672
       env:
         RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
         RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
-      options: >-
-        --health-cmd "rabbitmq-diagnostics -q ping"
-        --health-interval 30s
-        --health-timeout 10s
-        --health-retries 5
-        --health-start-period 10s
     clamav:
       image: clamav/clamav-debian:latest
       ports:
.github/workflows/platform-frontend-ci.yml (vendored, 6 changes)

@@ -6,16 +6,10 @@ on:
     paths:
       - ".github/workflows/platform-frontend-ci.yml"
       - "autogpt_platform/frontend/**"
-      - "autogpt_platform/backend/Dockerfile"
-      - "autogpt_platform/docker-compose.yml"
-      - "autogpt_platform/docker-compose.platform.yml"
   pull_request:
     paths:
       - ".github/workflows/platform-frontend-ci.yml"
       - "autogpt_platform/frontend/**"
-      - "autogpt_platform/backend/Dockerfile"
-      - "autogpt_platform/docker-compose.yml"
-      - "autogpt_platform/docker-compose.platform.yml"
   merge_group:
   workflow_dispatch:
@@ -53,6 +53,63 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
 COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
 RUN poetry run prisma generate && poetry run gen-prisma-stub
 
+# ============================== BACKEND SERVER ============================== #
+
+FROM debian:13-slim AS server
+
+WORKDIR /app
+
+ENV POETRY_HOME=/opt/poetry \
+    POETRY_NO_INTERACTION=1 \
+    POETRY_VIRTUALENVS_CREATE=true \
+    POETRY_VIRTUALENVS_IN_PROJECT=true \
+    DEBIAN_FRONTEND=noninteractive
+ENV PATH=/opt/poetry/bin:$PATH
+
+# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
+# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
+# for the bash_exec MCP tool.
+# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.13 \
+    python3-pip \
+    ffmpeg \
+    imagemagick \
+    jq \
+    ripgrep \
+    tree \
+    bubblewrap \
+    && rm -rf /var/lib/apt/lists/*
+
+COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
+COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
+# Copy Node.js installation for Prisma
+COPY --from=builder /usr/bin/node /usr/bin/node
+COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
+COPY --from=builder /usr/bin/npm /usr/bin/npm
+COPY --from=builder /usr/bin/npx /usr/bin/npx
+COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
+
+WORKDIR /app/autogpt_platform/backend
+
+# Copy only the .venv from builder (not the entire /app directory)
+# The .venv includes the generated Prisma client
+COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
+ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
+
+# Copy dependency files + autogpt_libs (path dependency)
+COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
+COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
+
+# Copy backend code + docs (for Copilot docs search)
+COPY autogpt_platform/backend ./
+COPY docs /app/docs
+RUN poetry install --no-ansi --only-root
+
+ENV PORT=8000
+
+CMD ["poetry", "run", "rest"]
+
 # =============================== DB MIGRATOR =============================== #
 
 # Lightweight migrate stage - only needs Prisma CLI, not full Python environment

@@ -84,59 +141,3 @@ COPY autogpt_platform/backend/schema.prisma ./
 COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
 COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
 COPY autogpt_platform/backend/migrations ./migrations
-
-# ============================== BACKEND SERVER ============================== #
-
-FROM debian:13-slim AS server
-
-WORKDIR /app
-
-ENV DEBIAN_FRONTEND=noninteractive
-
-# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
-# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
-# for the bash_exec MCP tool.
-# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    python3.13 \
-    python3-pip \
-    ffmpeg \
-    imagemagick \
-    jq \
-    ripgrep \
-    tree \
-    bubblewrap \
-    && rm -rf /var/lib/apt/lists/*
-
-# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
-COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
-COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
-# Copy Node.js installation for Prisma
-COPY --from=builder /usr/bin/node /usr/bin/node
-COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
-COPY --from=builder /usr/bin/npm /usr/bin/npm
-COPY --from=builder /usr/bin/npx /usr/bin/npx
-COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
-
-WORKDIR /app/autogpt_platform/backend
-
-# Copy only the .venv from builder (not the entire /app directory)
-# The .venv includes the generated Prisma client
-COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
-ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
-
-# Copy dependency files + autogpt_libs (path dependency)
-COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
-COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
-
-# Copy backend code + docs (for Copilot docs search)
-COPY autogpt_platform/backend ./
-COPY docs /app/docs
-# Install the project package to create entry point scripts in .venv/bin/
-# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
-RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
-    poetry install --no-ansi --only-root
-
-ENV PORT=8000
-
-CMD ["rest"]
@@ -1,9 +1,4 @@
-"""Common test fixtures for server tests.
-
-Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
-setup_test_user, and setup_admin_user are defined in the parent conftest.py
-(backend/conftest.py) and are available here automatically.
-"""
+"""Common test fixtures for server tests."""
 
 import pytest
 from pytest_snapshot.plugin import Snapshot

@@ -16,6 +11,54 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
     return snapshot
 
 
+@pytest.fixture
+def test_user_id() -> str:
+    """Test user ID fixture."""
+    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
+
+
+@pytest.fixture
+def admin_user_id() -> str:
+    """Admin user ID fixture."""
+    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
+
+
+@pytest.fixture
+def target_user_id() -> str:
+    """Target user ID fixture."""
+    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
+
+
+@pytest.fixture
+async def setup_test_user(test_user_id):
+    """Create test user in database before tests."""
+    from backend.data.user import get_or_create_user
+
+    # Create the test user in the database using JWT token format
+    user_data = {
+        "sub": test_user_id,
+        "email": "test@example.com",
+        "user_metadata": {"name": "Test User"},
+    }
+    await get_or_create_user(user_data)
+    return test_user_id
+
+
+@pytest.fixture
+async def setup_admin_user(admin_user_id):
+    """Create admin user in database before tests."""
+    from backend.data.user import get_or_create_user
+
+    # Create the admin user in the database using JWT token format
+    user_data = {
+        "sub": admin_user_id,
+        "email": "test-admin@example.com",
+        "user_metadata": {"name": "Test Admin"},
+    }
+    await get_or_create_user(user_data)
+    return admin_user_id
+
+
 @pytest.fixture
 def mock_jwt_user(test_user_id):
     """Provide mock JWT payload for regular user testing."""
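The fixtures introduced above are ordinary pytest fixtures, so any server test can request them by name and pytest resolves the chain (for example `test_user_id` -> `setup_test_user`). A minimal usage sketch follows; the asyncio marker is an assumption about the project's test configuration, while the fixture names and return values come from the diff above.

```python
# Hypothetical usage sketch (not part of the diff): a server test depending on
# the new fixtures directly.
import pytest


@pytest.mark.asyncio  # assumes pytest-asyncio (or equivalent) handles the async fixtures
async def test_uses_seeded_user(setup_test_user, mock_jwt_user):
    # setup_test_user seeds the user via get_or_create_user and returns its ID
    assert setup_test_user == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
    # mock_jwt_user (defined in the surrounding conftest) supplies the JWT payload
    assert mock_jwt_user is not None
```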
@@ -122,6 +122,24 @@ class ConnectionManager:
 
         return len(connections)
 
+    async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
+        """Broadcast a message to all active websocket connections."""
+        message = WSMessage(
+            method=method,
+            data=data,
+        ).model_dump_json()
+
+        connections = tuple(self.active_connections)
+        if not connections:
+            return 0
+
+        await asyncio.gather(
+            *(connection.send_text(message) for connection in connections),
+            return_exceptions=True,
+        )
+
+        return len(connections)
+
     async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
         if channel_key not in self.subscriptions:
             self.subscriptions[channel_key] = set()
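A hedged sketch of how a caller might use the new `broadcast_to_all` method. Only its keyword-only signature and return value (the number of connections reached) come from the diff; the import paths and the `WSMethod` member are illustrative assumptions.

```python
# Hypothetical usage sketch (not part of the diff). Module paths and the enum
# member are assumptions; only broadcast_to_all's behavior comes from the diff.
from backend.api.ws.manager import ConnectionManager  # assumed import path
from backend.api.ws.model import WSMethod              # assumed import path


async def notify_all_clients(manager: ConnectionManager) -> int:
    # Sends one message to every active connection; per-socket send errors are
    # swallowed by return_exceptions=True inside broadcast_to_all.
    return await manager.broadcast_to_all(
        method=WSMethod.GRAPH_EXECUTION_EVENT,  # assumed enum member
        data={"event": "llm_registry_refreshed"},
    )
```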
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
 from pydantic import BaseModel, Field
 
 from backend.api.external.middleware import require_permission
-from backend.copilot.model import ChatSession
-from backend.copilot.tools import find_agent_tool, run_agent_tool
-from backend.copilot.tools.models import ToolResponseBase
+from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
+from backend.api.features.chat.tools.models import ToolResponseBase
 from backend.data.auth.base import APIAuthorizationInfo
 
 logger = logging.getLogger(__name__)
@@ -176,30 +176,64 @@ async def get_execution_analytics_config(
         # Return with provider prefix for clarity
         return f"{provider_name}: {model_name}"
 
-    # Include all LlmModel values (no more filtering by hardcoded list)
-    recommended_model = LlmModel.GPT4O_MINI.value
-    for model in LlmModel:
+    # Get all models from the registry (dynamic, not hardcoded enum)
+    from backend.data import llm_registry
+    from backend.server.v2.llm import db as llm_db
+
+    # Get the recommended model from the database (configurable via admin UI)
+    recommended_model_slug = await llm_db.get_recommended_model_slug()
+
+    # Build the available models list
+    first_enabled_slug = None
+    for registry_model in llm_registry.iter_dynamic_models():
+        # Only include enabled models in the list
+        if not registry_model.is_enabled:
+            continue
+
+        # Track first enabled model as fallback
+        if first_enabled_slug is None:
+            first_enabled_slug = registry_model.slug
+
+        model = LlmModel(registry_model.slug)
         label = generate_model_label(model)
         # Add "(Recommended)" suffix to the recommended model
-        if model.value == recommended_model:
+        if registry_model.slug == recommended_model_slug:
             label += " (Recommended)"
 
         available_models.append(
             ModelInfo(
-                value=model.value,
+                value=registry_model.slug,
                 label=label,
-                provider=model.provider,
+                provider=registry_model.metadata.provider,
             )
         )
 
     # Sort models by provider and name for better UX
     available_models.sort(key=lambda x: (x.provider, x.label))
 
+    # Handle case where no models are available
+    if not available_models:
+        logger.warning(
+            "No enabled LLM models found in registry. "
+            "Ensure models are configured and enabled in the LLM Registry."
+        )
+        # Provide a placeholder entry so admins see meaningful feedback
+        available_models.append(
+            ModelInfo(
+                value="",
+                label="No models available - configure in LLM Registry",
+                provider="none",
+            )
+        )
+
+    # Use the DB recommended model, or fallback to first enabled model
+    final_recommended = recommended_model_slug or first_enabled_slug or ""
+
     return ExecutionAnalyticsConfig(
         available_models=available_models,
         default_system_prompt=DEFAULT_SYSTEM_PROMPT,
         default_user_prompt=DEFAULT_USER_PROMPT,
-        recommended_model=recommended_model,
+        recommended_model=final_recommended,
     )
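The new code resolves the recommendation in three steps: the admin-configured slug from the database, else the first enabled registry model, else an empty value alongside a placeholder entry. A minimal standalone sketch of that fallback chain (function and sample slugs are illustrative, not part of the diff):

```python
# Minimal sketch of the recommended-model fallback chain used above.
def pick_recommended(recommended_slug: str | None, enabled_slugs: list[str]) -> str:
    # Prefer the admin-configured recommendation, then the first enabled model,
    # and finally an empty string when nothing is enabled at all.
    if recommended_slug:
        return recommended_slug
    if enabled_slugs:
        return enabled_slugs[0]
    return ""


assert pick_recommended("gpt-4o-mini", ["gpt-4o"]) == "gpt-4o-mini"
assert pick_recommended(None, ["gpt-4o", "claude-3-5-sonnet"]) == "gpt-4o"
assert pick_recommended(None, []) == ""
```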
@@ -0,0 +1,593 @@
import logging

import autogpt_libs.auth
import fastapi

from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model

logger = logging.getLogger(__name__)

router = fastapi.APIRouter(
    tags=["llm", "admin"],
    dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)


async def _refresh_runtime_state() -> None:
    """Refresh the LLM registry and clear all related caches to ensure real-time updates."""
    logger.info("Refreshing LLM registry runtime state...")
    try:
        # Refresh registry from database
        await llm_registry.refresh_llm_registry()
        await refresh_llm_costs()

        # Clear block schema caches so they're regenerated with updated model options
        from backend.blocks._base import BlockSchema

        BlockSchema.clear_all_schema_caches()
        logger.info("Cleared all block schema caches")

        # Clear the /blocks endpoint cache so frontend gets updated schemas
        try:
            from backend.api.features.v1 import _get_cached_blocks

            _get_cached_blocks.cache_clear()
            logger.info("Cleared /blocks endpoint cache")
        except Exception as e:
            logger.warning("Failed to clear /blocks cache: %s", e)

        # Clear the v2 builder caches
        try:
            from backend.api.features.builder import db as builder_db

            builder_db._get_all_providers.cache_clear()
            logger.info("Cleared v2 builder providers cache")
            builder_db._build_cached_search_results.cache_clear()
            logger.info("Cleared v2 builder search results cache")
        except Exception as e:
            logger.debug("Could not clear v2 builder cache: %s", e)

        # Notify all executor services to refresh their registry cache
        from backend.data.llm_registry import publish_registry_refresh_notification

        await publish_registry_refresh_notification()
        logger.info("Published registry refresh notification")
    except Exception as exc:
        logger.exception(
            "LLM runtime state refresh failed; caches may be stale: %s", exc
        )


@router.get(
    "/providers",
    summary="List LLM providers",
    response_model=llm_model.LlmProvidersResponse,
)
async def list_llm_providers(include_models: bool = True):
    providers = await llm_db.list_providers(include_models=include_models)
    return llm_model.LlmProvidersResponse(providers=providers)


@router.post(
    "/providers",
    summary="Create LLM provider",
    response_model=llm_model.LlmProvider,
)
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
    provider = await llm_db.upsert_provider(request=request)
    await _refresh_runtime_state()
    return provider


@router.patch(
    "/providers/{provider_id}",
    summary="Update LLM provider",
    response_model=llm_model.LlmProvider,
)
async def update_llm_provider(
    provider_id: str,
    request: llm_model.UpsertLlmProviderRequest,
):
    provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
    await _refresh_runtime_state()
    return provider


@router.delete(
    "/providers/{provider_id}",
    summary="Delete LLM provider",
    response_model=dict,
)
async def delete_llm_provider(provider_id: str):
    """
    Delete an LLM provider.

    A provider can only be deleted if it has no associated models.
    Delete all models from the provider first before deleting the provider.
    """
    try:
        await llm_db.delete_provider(provider_id)
        await _refresh_runtime_state()
        logger.info("Deleted LLM provider '%s'", provider_id)
        return {"success": True, "message": "Provider deleted successfully"}
    except ValueError as e:
        logger.warning("Failed to delete provider '%s': %s", provider_id, e)
        raise fastapi.HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception("Failed to delete provider '%s': %s", provider_id, e)
        raise fastapi.HTTPException(status_code=500, detail=str(e))


@router.get(
    "/models",
    summary="List LLM models",
    response_model=llm_model.LlmModelsResponse,
)
async def list_llm_models(
    provider_id: str | None = fastapi.Query(default=None),
    page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
    page_size: int = fastapi.Query(
        default=50, ge=1, le=100, description="Number of models per page"
    ),
):
    return await llm_db.list_models(
        provider_id=provider_id, page=page, page_size=page_size
    )


@router.post(
    "/models",
    summary="Create LLM model",
    response_model=llm_model.LlmModel,
)
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
    model = await llm_db.create_model(request=request)
    await _refresh_runtime_state()
    return model


@router.patch(
    "/models/{model_id}",
    summary="Update LLM model",
    response_model=llm_model.LlmModel,
)
async def update_llm_model(
    model_id: str,
    request: llm_model.UpdateLlmModelRequest,
):
    model = await llm_db.update_model(model_id=model_id, request=request)
    await _refresh_runtime_state()
    return model


@router.patch(
    "/models/{model_id}/toggle",
    summary="Toggle LLM model availability",
    response_model=llm_model.ToggleLlmModelResponse,
)
async def toggle_llm_model(
    model_id: str,
    request: llm_model.ToggleLlmModelRequest,
):
    """
    Toggle a model's enabled status, optionally migrating workflows when disabling.

    If disabling a model and `migrate_to_slug` is provided, all workflows using
    this model will be migrated to the specified replacement model before disabling.
    A migration record is created which can be reverted later using the revert endpoint.

    Optional fields:
    - `migration_reason`: Reason for the migration (e.g., "Provider outage")
    - `custom_credit_cost`: Custom pricing override for billing during migration
    """
    try:
        result = await llm_db.toggle_model(
            model_id=model_id,
            is_enabled=request.is_enabled,
            migrate_to_slug=request.migrate_to_slug,
            migration_reason=request.migration_reason,
            custom_credit_cost=request.custom_credit_cost,
        )
        await _refresh_runtime_state()
        if result.nodes_migrated > 0:
            logger.info(
                "Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
                result.model.slug,
                "enabled" if request.is_enabled else "disabled",
                result.nodes_migrated,
                result.migrated_to_slug,
                result.migration_id,
            )
        return result
    except ValueError as exc:
        logger.warning("Model toggle validation failed: %s", exc)
        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to toggle model availability",
        ) from exc


@router.get(
    "/models/{model_id}/usage",
    summary="Get model usage count",
    response_model=llm_model.LlmModelUsageResponse,
)
async def get_llm_model_usage(model_id: str):
    """Get the number of workflow nodes using this model."""
    try:
        return await llm_db.get_model_usage(model_id=model_id)
    except ValueError as exc:
        raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to get model usage %s: %s", model_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to get model usage",
        ) from exc


@router.delete(
    "/models/{model_id}",
    summary="Delete LLM model and migrate workflows",
    response_model=llm_model.DeleteLlmModelResponse,
)
async def delete_llm_model(
    model_id: str,
    replacement_model_slug: str | None = fastapi.Query(
        default=None,
        description="Slug of the model to migrate existing workflows to (required only if workflows use this model)",
    ),
):
    """
    Delete a model and optionally migrate workflows using it to a replacement model.

    If no workflows are using this model, it can be deleted without providing a
    replacement. If workflows exist, replacement_model_slug is required.

    This endpoint:
    1. Counts how many workflow nodes use the model being deleted
    2. If nodes exist, validates the replacement model and migrates them
    3. Deletes the model record
    4. Refreshes all caches and notifies executors

    Example: DELETE /api/llm/admin/models/{id}?replacement_model_slug=gpt-4o
    Example (no usage): DELETE /api/llm/admin/models/{id}
    """
    try:
        result = await llm_db.delete_model(
            model_id=model_id, replacement_model_slug=replacement_model_slug
        )
        await _refresh_runtime_state()
        logger.info(
            "Deleted model '%s' and migrated %d nodes to '%s'",
            result.deleted_model_slug,
            result.nodes_migrated,
            result.replacement_model_slug,
        )
        return result
    except ValueError as exc:
        # Validation errors (model not found, replacement invalid, etc.)
        logger.warning("Model deletion validation failed: %s", exc)
        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to delete model and migrate workflows",
        ) from exc


# ============================================================================
# Migration Management Endpoints
# ============================================================================


@router.get(
    "/migrations",
    summary="List model migrations",
    response_model=llm_model.LlmMigrationsResponse,
)
async def list_llm_migrations(
    include_reverted: bool = fastapi.Query(
        default=False, description="Include reverted migrations in the list"
    ),
):
    """
    List all model migrations.

    Migrations are created when disabling a model with the migrate_to_slug option.
    They can be reverted to restore the original model configuration.
    """
    try:
        migrations = await llm_db.list_migrations(include_reverted=include_reverted)
        return llm_model.LlmMigrationsResponse(migrations=migrations)
    except Exception as exc:
        logger.exception("Failed to list migrations: %s", exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to list migrations",
        ) from exc


@router.get(
    "/migrations/{migration_id}",
    summary="Get migration details",
    response_model=llm_model.LlmModelMigration,
)
async def get_llm_migration(migration_id: str):
    """Get details of a specific migration."""
    try:
        migration = await llm_db.get_migration(migration_id)
        if not migration:
            raise fastapi.HTTPException(
                status_code=404, detail=f"Migration '{migration_id}' not found"
            )
        return migration
    except fastapi.HTTPException:
        raise
    except Exception as exc:
        logger.exception("Failed to get migration %s: %s", migration_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to get migration",
        ) from exc


@router.post(
    "/migrations/{migration_id}/revert",
    summary="Revert a model migration",
    response_model=llm_model.RevertMigrationResponse,
)
async def revert_llm_migration(
    migration_id: str,
    request: llm_model.RevertMigrationRequest | None = None,
):
    """
    Revert a model migration, restoring affected workflows to their original model.

    This only reverts the specific nodes that were part of the migration.
    The source model must exist for the revert to succeed.

    Options:
    - `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)

    Response includes:
    - `nodes_reverted`: Number of nodes successfully reverted
    - `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
    - `source_model_re_enabled`: Whether the source model was re-enabled

    Requirements:
    - Migration must not already be reverted
    - Source model must exist
    """
    try:
        re_enable = request.re_enable_source_model if request else True
        result = await llm_db.revert_migration(
            migration_id,
            re_enable_source_model=re_enable,
        )
        await _refresh_runtime_state()
        logger.info(
            "Reverted migration '%s': %d nodes restored from '%s' to '%s' "
            "(%d already changed, source re-enabled=%s)",
            migration_id,
            result.nodes_reverted,
            result.target_model_slug,
            result.source_model_slug,
            result.nodes_already_changed,
            result.source_model_re_enabled,
        )
        return result
    except ValueError as exc:
        logger.warning("Migration revert validation failed: %s", exc)
        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to revert migration %s: %s", migration_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to revert migration",
        ) from exc


# ============================================================================
# Creator Management Endpoints
# ============================================================================


@router.get(
    "/creators",
    summary="List model creators",
    response_model=llm_model.LlmCreatorsResponse,
)
async def list_llm_creators():
    """
    List all model creators.

    Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
    This is distinct from providers who host/serve the models (e.g., OpenRouter).
    """
    try:
        creators = await llm_db.list_creators()
        return llm_model.LlmCreatorsResponse(creators=creators)
    except Exception as exc:
        logger.exception("Failed to list creators: %s", exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to list creators",
        ) from exc


@router.get(
    "/creators/{creator_id}",
    summary="Get creator details",
    response_model=llm_model.LlmModelCreator,
)
async def get_llm_creator(creator_id: str):
    """Get details of a specific model creator."""
    try:
        creator = await llm_db.get_creator(creator_id)
        if not creator:
            raise fastapi.HTTPException(
                status_code=404, detail=f"Creator '{creator_id}' not found"
            )
        return creator
    except fastapi.HTTPException:
        raise
    except Exception as exc:
        logger.exception("Failed to get creator %s: %s", creator_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to get creator",
        ) from exc


@router.post(
    "/creators",
    summary="Create model creator",
    response_model=llm_model.LlmModelCreator,
)
async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
    """
    Create a new model creator.

    A creator represents an organization that creates/trains AI models,
    such as OpenAI, Anthropic, Meta, or Google.
    """
    try:
        creator = await llm_db.upsert_creator(request=request)
        await _refresh_runtime_state()
        logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
        return creator
    except Exception as exc:
        logger.exception("Failed to create creator: %s", exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to create creator",
        ) from exc


@router.patch(
    "/creators/{creator_id}",
    summary="Update model creator",
    response_model=llm_model.LlmModelCreator,
)
async def update_llm_creator(
    creator_id: str,
    request: llm_model.UpsertLlmCreatorRequest,
):
    """Update an existing model creator."""
    try:
        creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
        await _refresh_runtime_state()
        logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
        return creator
    except Exception as exc:
        logger.exception("Failed to update creator %s: %s", creator_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to update creator",
        ) from exc


@router.delete(
    "/creators/{creator_id}",
    summary="Delete model creator",
    response_model=dict,
)
async def delete_llm_creator(creator_id: str):
    """
    Delete a model creator.

    This will remove the creator association from all models that reference it
    (sets creatorId to NULL), but will not delete the models themselves.
    """
    try:
        await llm_db.delete_creator(creator_id)
        await _refresh_runtime_state()
        logger.info("Deleted model creator '%s'", creator_id)
        return {"success": True, "message": f"Creator '{creator_id}' deleted"}
    except ValueError as exc:
        logger.warning("Creator deletion validation failed: %s", exc)
        raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to delete creator %s: %s", creator_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to delete creator",
        ) from exc


# ============================================================================
# Recommended Model Endpoints
# ============================================================================


@router.get(
    "/recommended-model",
    summary="Get recommended model",
    response_model=llm_model.RecommendedModelResponse,
)
async def get_recommended_model():
    """
    Get the currently recommended LLM model.

    The recommended model is shown to users as the default/suggested option
    in model selection dropdowns.
    """
    try:
        model = await llm_db.get_recommended_model()
        return llm_model.RecommendedModelResponse(
            model=model,
            slug=model.slug if model else None,
        )
    except Exception as exc:
        logger.exception("Failed to get recommended model: %s", exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to get recommended model",
        ) from exc


@router.post(
    "/recommended-model",
    summary="Set recommended model",
    response_model=llm_model.SetRecommendedModelResponse,
)
async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
    """
    Set a model as the recommended model.

    This clears the recommended flag from any other model and sets it on
    the specified model. The model must be enabled to be set as recommended.

    The recommended model is displayed to users as the default/suggested
    option in model selection dropdowns throughout the platform.
    """
    try:
        model, previous_slug = await llm_db.set_recommended_model(request.model_id)
        await _refresh_runtime_state()
        logger.info(
            "Set recommended model to '%s' (previous: %s)",
            model.slug,
            previous_slug or "none",
        )
        return llm_model.SetRecommendedModelResponse(
            model=model,
            previous_recommended_slug=previous_slug,
            message=f"Model '{model.display_name}' is now the recommended model",
        )
    except ValueError as exc:
        logger.warning("Set recommended model validation failed: %s", exc)
        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to set recommended model: %s", exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to set recommended model",
        ) from exc
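The router above only defines paths relative to wherever it is mounted; the test module that follows mounts it at `/admin/llm`. A hedged sketch of how an admin client might call the toggle-with-migration endpoint over HTTP, assuming that prefix and a bearer-token admin auth scheme (both assumptions for illustration):

```python
# Hypothetical usage sketch (not part of the diff). Mount prefix, auth scheme,
# and the replacement slug are assumptions; the request fields come from
# ToggleLlmModelRequest as used in the route above.
import httpx


def disable_model_with_migration(base_url: str, token: str, model_id: str) -> dict:
    headers = {"Authorization": f"Bearer {token}"}  # assumed auth scheme
    payload = {
        "is_enabled": False,
        "migrate_to_slug": "gpt-4o-mini",       # replacement model slug (assumed)
        "migration_reason": "Provider outage",   # optional, per the docstring above
    }
    resp = httpx.patch(
        f"{base_url}/admin/llm/models/{model_id}/toggle",
        json=payload,
        headers=headers,
    )
    resp.raise_for_status()
    # Response mirrors ToggleLlmModelResponse: model, nodes_migrated, migration_id, ...
    return resp.json()
```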
@@ -0,0 +1,491 @@
|
|||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
import fastapi.testclient
|
||||||
|
import pytest
|
||||||
|
import pytest_mock
|
||||||
|
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||||
|
from pytest_snapshot.plugin import Snapshot
|
||||||
|
|
||||||
|
import backend.api.features.admin.llm_routes as llm_routes
|
||||||
|
from backend.server.v2.llm import model as llm_model
|
||||||
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
|
app = fastapi.FastAPI()
|
||||||
|
app.include_router(llm_routes.router, prefix="/admin/llm")
|
||||||
|
|
||||||
|
client = fastapi.testclient.TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_app_admin_auth(mock_jwt_admin):
|
||||||
|
"""Setup admin auth overrides for all tests in this module"""
|
||||||
|
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_llm_providers_success(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
configured_snapshot: Snapshot,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful listing of LLM providers"""
|
||||||
|
# Mock the database function
|
||||||
|
mock_providers = [
|
||||||
|
{
|
||||||
|
"id": "provider-1",
|
||||||
|
"name": "openai",
|
||||||
|
"display_name": "OpenAI",
|
||||||
|
"description": "OpenAI LLM provider",
|
||||||
|
"supports_tools": True,
|
||||||
|
"supports_json_output": True,
|
||||||
|
"supports_reasoning": False,
|
||||||
|
"supports_parallel_tool": True,
|
||||||
|
"metadata": {},
|
||||||
|
"models": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "provider-2",
|
||||||
|
"name": "anthropic",
|
||||||
|
"display_name": "Anthropic",
|
||||||
|
"description": "Anthropic LLM provider",
|
||||||
|
"supports_tools": True,
|
||||||
|
"supports_json_output": True,
|
||||||
|
"supports_reasoning": False,
|
||||||
|
"supports_parallel_tool": True,
|
||||||
|
"metadata": {},
|
||||||
|
"models": [],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.llm_routes.llm_db.list_providers",
|
||||||
|
new=AsyncMock(return_value=mock_providers),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/admin/llm/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_data = response.json()
|
||||||
|
assert len(response_data["providers"]) == 2
|
||||||
|
assert response_data["providers"][0]["name"] == "openai"
|
||||||
|
|
||||||
|
# Snapshot test the response (must be string)
|
||||||
|
configured_snapshot.assert_match(
|
||||||
|
json.dumps(response_data, indent=2, sort_keys=True),
|
||||||
|
"list_llm_providers_success.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_llm_models_success(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
configured_snapshot: Snapshot,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful listing of LLM models with pagination"""
|
||||||
|
# Mock the database function - now returns LlmModelsResponse
|
||||||
|
mock_model = llm_model.LlmModel(
|
||||||
|
id="model-1",
|
||||||
|
slug="gpt-4o",
|
||||||
|
display_name="GPT-4o",
|
||||||
|
description="GPT-4 Optimized",
|
||||||
|
provider_id="provider-1",
|
||||||
|
context_window=128000,
|
||||||
|
max_output_tokens=16384,
|
||||||
|
is_enabled=True,
|
||||||
|
capabilities={},
|
||||||
|
metadata={},
|
||||||
|
costs=[
|
||||||
|
llm_model.LlmModelCost(
|
||||||
|
id="cost-1",
|
||||||
|
credit_cost=10,
|
||||||
|
credential_provider="openai",
|
||||||
|
metadata={},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_response = llm_model.LlmModelsResponse(
|
||||||
|
models=[mock_model],
|
||||||
|
pagination=Pagination(
|
||||||
|
total_items=1,
|
||||||
|
total_pages=1,
|
||||||
|
current_page=1,
|
||||||
|
page_size=50,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.llm_routes.llm_db.list_models",
|
||||||
|
new=AsyncMock(return_value=mock_response),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/admin/llm/models")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_data = response.json()
|
||||||
|
assert len(response_data["models"]) == 1
|
||||||
|
assert response_data["models"][0]["slug"] == "gpt-4o"
|
||||||
|
assert response_data["pagination"]["total_items"] == 1
|
||||||
|
assert response_data["pagination"]["page_size"] == 50
|
||||||
|
|
||||||
|
# Snapshot test the response (must be string)
|
||||||
|
configured_snapshot.assert_match(
|
||||||
|
json.dumps(response_data, indent=2, sort_keys=True),
|
||||||
|
"list_llm_models_success.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_provider_success(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
configured_snapshot: Snapshot,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful creation of LLM provider"""
|
||||||
|
mock_provider = {
|
||||||
|
"id": "new-provider-id",
|
||||||
|
"name": "groq",
|
||||||
|
"display_name": "Groq",
|
||||||
|
"description": "Groq LLM provider",
|
||||||
|
"supports_tools": True,
|
||||||
|
"supports_json_output": True,
|
||||||
|
"supports_reasoning": False,
|
||||||
|
"supports_parallel_tool": False,
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.llm_routes.llm_db.upsert_provider",
|
||||||
|
new=AsyncMock(return_value=mock_provider),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_refresh = mocker.patch(
|
||||||
|
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||||
|
new=AsyncMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"name": "groq",
|
||||||
|
"display_name": "Groq",
|
||||||
|
"description": "Groq LLM provider",
|
||||||
|
"supports_tools": True,
|
||||||
|
"supports_json_output": True,
|
||||||
|
"supports_reasoning": False,
|
||||||
|
"supports_parallel_tool": False,
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/admin/llm/providers", json=request_data)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_data = response.json()
|
||||||
|
assert response_data["name"] == "groq"
|
||||||
|
assert response_data["display_name"] == "Groq"
|
||||||
|
|
||||||
|
# Verify refresh was called
|
||||||
|
mock_refresh.assert_called_once()
|
||||||
|
|
||||||
|
# Snapshot test the response (must be string)
|
||||||
|
configured_snapshot.assert_match(
|
||||||
|
json.dumps(response_data, indent=2, sort_keys=True),
|
||||||
|
"create_llm_provider_success.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_model_success(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
configured_snapshot: Snapshot,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful creation of LLM model"""
|
||||||
|
mock_model = {
|
||||||
|
"id": "new-model-id",
|
||||||
|
"slug": "gpt-4.1-mini",
|
||||||
|
"display_name": "GPT-4.1 Mini",
|
||||||
|
"description": "Latest GPT-4.1 Mini model",
|
||||||
|
"provider_id": "provider-1",
|
||||||
|
"context_window": 128000,
|
||||||
|
"max_output_tokens": 16384,
|
||||||
|
"is_enabled": True,
|
||||||
|
"capabilities": {},
|
||||||
|
"metadata": {},
|
||||||
|
"costs": [
|
||||||
|
{
|
||||||
|
"id": "cost-id",
|
||||||
|
"credit_cost": 5,
|
||||||
|
"credential_provider": "openai",
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.llm_routes.llm_db.create_model",
|
||||||
|
new=AsyncMock(return_value=mock_model),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_refresh = mocker.patch(
|
||||||
|
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||||
|
new=AsyncMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"slug": "gpt-4.1-mini",
|
||||||
|
"display_name": "GPT-4.1 Mini",
|
||||||
|
"description": "Latest GPT-4.1 Mini model",
|
||||||
|
"provider_id": "provider-1",
|
||||||
|
"context_window": 128000,
|
||||||
|
"max_output_tokens": 16384,
|
||||||
|
"is_enabled": True,
|
||||||
|
"capabilities": {},
|
||||||
|
"metadata": {},
|
||||||
|
"costs": [
|
||||||
|
{
|
||||||
|
"credit_cost": 5,
|
||||||
|
"credential_provider": "openai",
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/admin/llm/models", json=request_data)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_data = response.json()
|
||||||
|
assert response_data["slug"] == "gpt-4.1-mini"
|
||||||
|
assert response_data["is_enabled"] is True
|
||||||
|
|
||||||
|
# Verify refresh was called
|
||||||
|
mock_refresh.assert_called_once()
|
||||||
|
|
||||||
|
# Snapshot test the response (must be string)
|
||||||
|
configured_snapshot.assert_match(
|
||||||
|
json.dumps(response_data, indent=2, sort_keys=True),
|
||||||
|
"create_llm_model_success.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_llm_model_success(
|
||||||
|
mocker: pytest_mock.MockFixture,
|
||||||
|
configured_snapshot: Snapshot,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful update of LLM model"""
|
||||||
|
mock_model = {
|
||||||
|
"id": "model-1",
|
||||||
|
"slug": "gpt-4o",
|
||||||
|
"display_name": "GPT-4o Updated",
|
||||||
|
"description": "Updated description",
|
||||||
|
"provider_id": "provider-1",
|
||||||
|
"context_window": 256000,
|
||||||
|
"max_output_tokens": 32768,
|
||||||
|
"is_enabled": True,
|
||||||
|
"capabilities": {},
|
||||||
|
"metadata": {},
|
||||||
|
"costs": [
|
||||||
|
{
|
||||||
|
"id": "cost-1",
|
||||||
|
"credit_cost": 15,
|
||||||
|
"credential_provider": "openai",
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.llm_routes.llm_db.update_model",
|
||||||
|
new=AsyncMock(return_value=mock_model),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_refresh = mocker.patch(
|
||||||
|
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||||
|
new=AsyncMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"display_name": "GPT-4o Updated",
|
||||||
|
"description": "Updated description",
|
||||||
|
"context_window": 256000,
|
||||||
|
"max_output_tokens": 32768,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.patch("/admin/llm/models/model-1", json=request_data)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_data = response.json()
|
||||||
|
assert response_data["display_name"] == "GPT-4o Updated"
|
||||||
|
assert response_data["context_window"] == 256000
|
||||||
|
|
||||||
|
# Verify refresh was called
|
||||||
|
mock_refresh.assert_called_once()
|
||||||
|
|
||||||
|
    # Snapshot test the response (must be string)
    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "update_llm_model_success.json",
    )


def test_toggle_llm_model_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
) -> None:
    """Test successful toggling of LLM model enabled status"""
    # Create a proper mock model object
    mock_model = llm_model.LlmModel(
        id="model-1",
        slug="gpt-4o",
        display_name="GPT-4o",
        description="GPT-4 Optimized",
        provider_id="provider-1",
        context_window=128000,
        max_output_tokens=16384,
        is_enabled=False,
        capabilities={},
        metadata={},
        costs=[],
    )

    # Create a proper ToggleLlmModelResponse
    mock_response = llm_model.ToggleLlmModelResponse(
        model=mock_model,
        nodes_migrated=0,
        migrated_to_slug=None,
        migration_id=None,
    )

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.toggle_model",
        new=AsyncMock(return_value=mock_response),
    )

    mock_refresh = mocker.patch(
        "backend.api.features.admin.llm_routes._refresh_runtime_state",
        new=AsyncMock(),
    )

    request_data = {"is_enabled": False}

    response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["model"]["is_enabled"] is False

    # Verify refresh was called
    mock_refresh.assert_called_once()

    # Snapshot test the response (must be string)
    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "toggle_llm_model_success.json",
    )


def test_delete_llm_model_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
) -> None:
    """Test successful deletion of LLM model with migration"""
    # Create a proper DeleteLlmModelResponse
    mock_response = llm_model.DeleteLlmModelResponse(
        deleted_model_slug="gpt-3.5-turbo",
        deleted_model_display_name="GPT-3.5 Turbo",
        replacement_model_slug="gpt-4o-mini",
        nodes_migrated=42,
        message="Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
        "and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
    )

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.delete_model",
        new=AsyncMock(return_value=mock_response),
    )

    mock_refresh = mocker.patch(
        "backend.api.features.admin.llm_routes._refresh_runtime_state",
        new=AsyncMock(),
    )

    response = client.delete(
        "/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
    )

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
    assert response_data["nodes_migrated"] == 42
    assert response_data["replacement_model_slug"] == "gpt-4o-mini"

    # Verify refresh was called
    mock_refresh.assert_called_once()

    # Snapshot test the response (must be string)
    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "delete_llm_model_success.json",
    )


def test_delete_llm_model_validation_error(
    mocker: pytest_mock.MockFixture,
) -> None:
    """Test deletion fails with proper error when validation fails"""
    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.delete_model",
        new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
    )

    response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")

    assert response.status_code == 400
    assert "Replacement model 'invalid' not found" in response.json()["detail"]


def test_delete_llm_model_no_replacement_with_usage(
    mocker: pytest_mock.MockFixture,
) -> None:
    """Test deletion fails when nodes exist but no replacement is provided"""
    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.delete_model",
        new=AsyncMock(
            side_effect=ValueError(
                "Cannot delete model 'test-model': 5 workflow node(s) are using it. "
                "Please provide a replacement_model_slug to migrate them."
            )
        ),
    )

    response = client.delete("/admin/llm/models/model-1")

    assert response.status_code == 400
    assert "workflow node(s) are using it" in response.json()["detail"]


def test_delete_llm_model_no_replacement_no_usage(
    mocker: pytest_mock.MockFixture,
) -> None:
    """Test deletion succeeds when no nodes use the model and no replacement is provided"""
    mock_response = llm_model.DeleteLlmModelResponse(
        deleted_model_slug="unused-model",
        deleted_model_display_name="Unused Model",
        replacement_model_slug=None,
        nodes_migrated=0,
        message="Successfully deleted model 'Unused Model' (unused-model). No workflows were using this model.",
    )

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.delete_model",
        new=AsyncMock(return_value=mock_response),
    )

    mock_refresh = mocker.patch(
        "backend.api.features.admin.llm_routes._refresh_runtime_state",
        new=AsyncMock(),
    )

    response = client.delete("/admin/llm/models/model-1")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["deleted_model_slug"] == "unused-model"
    assert response_data["nodes_migrated"] == 0
    assert response_data["replacement_model_slug"] is None
    mock_refresh.assert_called_once()
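The snapshot assertions above require a string payload, which is why each response dict is serialized with json.dumps before being compared. If the Snapshot fixture used here is pytest-snapshot's, a configured_snapshot fixture along these lines would be enough; the snapshot directory name is an assumption, not taken from this change:

import pytest
from pytest_snapshot.plugin import Snapshot


@pytest.fixture
def configured_snapshot(snapshot: Snapshot) -> Snapshot:
    # Keep the generated .json snapshots next to the tests so they are reviewed with the code.
    snapshot.snapshot_dir = "snapshots"  # assumed location
    return snapshot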
@@ -20,6 +20,7 @@ from backend.blocks._base import (
 )
 from backend.blocks.llm import LlmModel
 from backend.data.db import query_raw_with_schema
+from backend.data.llm_registry import get_all_model_slugs_for_validation
 from backend.integrations.providers import ProviderName
 from backend.util.cache import cached
 from backend.util.models import Pagination
@@ -36,7 +37,14 @@ from .model import (
 )

 logger = logging.getLogger(__name__)
-llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
+
+
+def _get_llm_models() -> list[str]:
+    """Get LLM model names for search matching from the registry."""
+    return [
+        slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
+    ]
+

 MAX_LIBRARY_AGENT_RESULTS = 100
 MAX_MARKETPLACE_AGENT_RESULTS = 100
@@ -501,8 +509,10 @@ async def _get_static_counts():
 def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
     for field in schema_cls.model_fields.values():
         if field.annotation == LlmModel:
-            # Check if query matches any value in llm_models
-            if any(query in name for name in llm_models):
+            # Normalize query same as model slugs (lowercase, hyphens to spaces)
+            normalized_model_query = query.lower().replace("-", " ")
+            # Check if query matches any value in llm_models from registry
+            if any(normalized_model_query in name for name in _get_llm_models()):
                 return True
     return False

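A minimal sketch of the matching behaviour this hunk introduces: both the registry slugs and the incoming query are lowercased with hyphens turned into spaces, so a hyphenated search such as "gpt-4o" still matches. The slugs below are illustrative only:

def _matches(query: str, slugs: list[str]) -> bool:
    names = [slug.lower().replace("-", " ") for slug in slugs]
    normalized = query.lower().replace("-", " ")
    return any(normalized in name for name in names)


assert _matches("GPT-4o", ["gpt-4o-mini", "claude-3-5-sonnet"])  # "gpt 4o" is a substring of "gpt 4o mini"
assert not _matches("grok", ["gpt-4o-mini"])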
@@ -37,10 +37,12 @@ stale pending messages from dead consumers.

 import asyncio
 import logging
+import os
 import uuid
 from typing import Any

 import orjson
+from prisma import Prisma
 from pydantic import BaseModel
 from redis.exceptions import ResponseError

@@ -67,8 +69,8 @@ class OperationCompleteMessage(BaseModel):
 class ChatCompletionConsumer:
     """Consumer for chat operation completion messages from Redis Streams.

-    Database operations are handled through the chat_db() accessor, which
-    routes through DatabaseManager RPC when Prisma is not directly connected.
+    This consumer initializes its own Prisma client in start() to ensure
+    database operations work correctly within this async context.

     Uses Redis consumer groups to allow multiple platform pods to consume
     messages reliably with automatic redelivery on failure.
@@ -77,6 +79,7 @@ class ChatCompletionConsumer:
     def __init__(self):
         self._consumer_task: asyncio.Task | None = None
         self._running = False
+        self._prisma: Prisma | None = None
         self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"

     async def start(self) -> None:
@@ -112,6 +115,15 @@ class ChatCompletionConsumer:
             f"Chat completion consumer started (consumer: {self._consumer_name})"
         )

+    async def _ensure_prisma(self) -> Prisma:
+        """Lazily initialize Prisma client on first use."""
+        if self._prisma is None:
+            database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
+            self._prisma = Prisma(datasource={"url": database_url})
+            await self._prisma.connect()
+            logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
+        return self._prisma
+
     async def stop(self) -> None:
         """Stop the completion consumer."""
         self._running = False
@@ -124,6 +136,11 @@ class ChatCompletionConsumer:
                 pass
             self._consumer_task = None

+        if self._prisma:
+            await self._prisma.disconnect()
+            self._prisma = None
+            logger.info("[COMPLETION] Consumer Prisma client disconnected")
+
         logger.info("Chat completion consumer stopped")

     async def _consume_messages(self) -> None:
@@ -235,7 +252,7 @@ class ChatCompletionConsumer:
         # XAUTOCLAIM after min_idle_time expires

     async def _handle_message(self, body: bytes) -> None:
-        """Handle a completion message."""
+        """Handle a completion message using our own Prisma client."""
         try:
             data = orjson.loads(body)
             message = OperationCompleteMessage(**data)
@@ -285,7 +302,8 @@ class ChatCompletionConsumer:
         message: OperationCompleteMessage,
     ) -> None:
         """Handle successful operation completion."""
-        await process_operation_success(task, message.result)
+        prisma = await self._ensure_prisma()
+        await process_operation_success(task, message.result, prisma)

     async def _handle_failure(
         self,
@@ -293,7 +311,8 @@ class ChatCompletionConsumer:
         message: OperationCompleteMessage,
     ) -> None:
         """Handle failed operation completion."""
-        await process_operation_failure(task, message.error)
+        prisma = await self._ensure_prisma()
+        await process_operation_failure(task, message.error, prisma)


 # Module-level consumer instance
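A sketch of how the consumer is intended to be driven, assuming one instance per process; the surrounding service wiring is not shown in this hunk and the function name below is an assumption:

import asyncio


async def run_consumer_for_a_while() -> None:
    consumer = ChatCompletionConsumer()
    await consumer.start()  # begins consuming; Prisma connects lazily on the first handled message
    try:
        await asyncio.sleep(60)  # stand-in for the service's normal lifetime
    finally:
        await consumer.stop()  # cancels the consumer task and disconnects the dedicated Prisma client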
@@ -9,8 +9,7 @@ import logging
 from typing import Any

 import orjson
-
-from backend.data.db_accessors import chat_db
+from prisma import Prisma

 from . import service as chat_service
 from . import stream_registry
@@ -73,40 +72,48 @@ async def _update_tool_message(
     session_id: str,
     tool_call_id: str,
     content: str,
+    prisma_client: Prisma | None,
 ) -> None:
-    """Update tool message in database using the chat_db accessor.
-
-    Routes through DatabaseManager RPC when Prisma is not directly
-    connected (e.g. in the CoPilot Executor microservice).
+    """Update tool message in database.

     Args:
         session_id: The session ID
         tool_call_id: The tool call ID to update
         content: The new content for the message
+        prisma_client: Optional Prisma client. If None, uses chat_service.

     Raises:
-        ToolMessageUpdateError: If the database update fails.
+        ToolMessageUpdateError: If the database update fails. The caller should
+            handle this to avoid marking the task as completed with inconsistent state.
     """
     try:
-        updated = await chat_db().update_tool_message_content(
-            session_id=session_id,
-            tool_call_id=tool_call_id,
-            new_content=content,
-        )
-        if not updated:
-            raise ToolMessageUpdateError(
-                f"No message found with tool_call_id="
-                f"{tool_call_id} in session {session_id}"
-            )
+        if prisma_client:
+            # Use provided Prisma client (for consumer with its own connection)
+            updated_count = await prisma_client.chatmessage.update_many(
+                where={
+                    "sessionId": session_id,
+                    "toolCallId": tool_call_id,
+                },
+                data={"content": content},
+            )
+            # Check if any rows were updated - 0 means message not found
+            if updated_count == 0:
+                raise ToolMessageUpdateError(
+                    f"No message found with tool_call_id={tool_call_id} in session {session_id}"
+                )
+        else:
+            # Use service function (for webhook endpoint)
+            await chat_service._update_pending_operation(
+                session_id=session_id,
+                tool_call_id=tool_call_id,
+                result=content,
+            )
     except ToolMessageUpdateError:
         raise
     except Exception as e:
-        logger.error(
-            f"[COMPLETION] Failed to update tool message: {e}",
-            exc_info=True,
-        )
+        logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
         raise ToolMessageUpdateError(
-            f"Failed to update tool message for tool call #{tool_call_id}: {e}"
+            f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
        ) from e
@@ -195,6 +202,7 @@ async def _save_agent_from_result(
 async def process_operation_success(
     task: stream_registry.ActiveTask,
     result: dict | str | None,
+    prisma_client: Prisma | None = None,
 ) -> None:
     """Handle successful operation completion.

@@ -204,10 +212,12 @@ async def process_operation_success(
     Args:
         task: The active task that completed
         result: The result data from the operation
+        prisma_client: Optional Prisma client for database operations.
+            If None, uses chat_service._update_pending_operation instead.

     Raises:
-        ToolMessageUpdateError: If the database update fails. The task
-            will be marked as failed instead of completed.
+        ToolMessageUpdateError: If the database update fails. The task will be
+            marked as failed instead of completed to avoid inconsistent state.
     """
     # For agent generation tools, save the agent to library
     if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
@@ -240,6 +250,7 @@ async def process_operation_success(
             session_id=task.session_id,
             tool_call_id=task.tool_call_id,
             content=result_str,
+            prisma_client=prisma_client,
         )
     except ToolMessageUpdateError:
         # DB update failed - mark task as failed to avoid inconsistent state
@@ -282,15 +293,18 @@ async def process_operation_success(
 async def process_operation_failure(
     task: stream_registry.ActiveTask,
     error: str | None,
+    prisma_client: Prisma | None = None,
 ) -> None:
     """Handle failed operation completion.

-    Publishes the error to the stream registry, updates the database
-    with the error response, and marks the task as failed.
+    Publishes the error to the stream registry, updates the database with
+    the error response, and marks the task as failed.

     Args:
         task: The active task that failed
         error: The error message from the operation
+        prisma_client: Optional Prisma client for database operations.
+            If None, uses chat_service._update_pending_operation instead.
     """
     error_msg = error or "Operation failed"

@@ -311,6 +325,7 @@ async def process_operation_failure(
             session_id=task.session_id,
             tool_call_id=task.tool_call_id,
             content=error_response.model_dump_json(),
+            prisma_client=prisma_client,
         )
     except ToolMessageUpdateError:
         # DB update failed - log but continue with cleanup
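The prisma_client parameter keeps both call sites working: the Redis Streams consumer passes its dedicated client, while the webhook path omits it and falls back to chat_service._update_pending_operation. Roughly, per the signatures above:

# From the consumer (dedicated connection):
prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)

# From the webhook endpoint (service fallback; prisma_client defaults to None):
await process_operation_success(task, result)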
@@ -14,27 +14,29 @@ from prisma.types import (
     ChatSessionWhereInput,
 )

-from backend.data import db
+from backend.data.db import transaction
 from backend.util.json import SafeJson

-from .model import ChatMessage, ChatSession, ChatSessionInfo

 logger = logging.getLogger(__name__)


-async def get_chat_session(session_id: str) -> ChatSession | None:
+async def get_chat_session(session_id: str) -> PrismaChatSession | None:
     """Get a chat session by ID from the database."""
     session = await PrismaChatSession.prisma().find_unique(
         where={"id": session_id},
-        include={"Messages": {"order_by": {"sequence": "asc"}}},
+        include={"Messages": True},
     )
-    return ChatSession.from_db(session) if session else None
+    if session and session.Messages:
+        # Sort messages by sequence in Python - Prisma Python client doesn't support
+        # order_by in include clauses (unlike Prisma JS), so we sort after fetching
+        session.Messages.sort(key=lambda m: m.sequence)
+    return session


 async def create_chat_session(
     session_id: str,
     user_id: str,
-) -> ChatSessionInfo:
+) -> PrismaChatSession:
     """Create a new chat session in the database."""
     data = ChatSessionCreateInput(
         id=session_id,
@@ -43,8 +45,7 @@ async def create_chat_session(
         successfulAgentRuns=SafeJson({}),
         successfulAgentSchedules=SafeJson({}),
     )
-    prisma_session = await PrismaChatSession.prisma().create(data=data)
-    return ChatSessionInfo.from_db(prisma_session)
+    return await PrismaChatSession.prisma().create(data=data)


 async def update_chat_session(
@@ -55,7 +56,7 @@ async def update_chat_session(
     total_prompt_tokens: int | None = None,
     total_completion_tokens: int | None = None,
     title: str | None = None,
-) -> ChatSession | None:
+) -> PrismaChatSession | None:
     """Update a chat session's metadata."""
     data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}

@@ -75,9 +76,12 @@ async def update_chat_session(
     session = await PrismaChatSession.prisma().update(
         where={"id": session_id},
         data=data,
-        include={"Messages": {"order_by": {"sequence": "asc"}}},
+        include={"Messages": True},
     )
-    return ChatSession.from_db(session) if session else None
+    if session and session.Messages:
+        # Sort in Python - Prisma Python doesn't support order_by in include clauses
+        session.Messages.sort(key=lambda m: m.sequence)
+    return session


 async def add_chat_message(
@@ -90,7 +94,7 @@ async def add_chat_message(
     refusal: str | None = None,
     tool_calls: list[dict[str, Any]] | None = None,
     function_call: dict[str, Any] | None = None,
-) -> ChatMessage:
+) -> PrismaChatMessage:
     """Add a message to a chat session."""
     # Build input dict dynamically rather than using ChatMessageCreateInput directly
     # because Prisma's TypedDict validation rejects optional fields set to None.
@@ -125,14 +129,14 @@ async def add_chat_message(
         ),
         PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
     )
-    return ChatMessage.from_db(message)
+    return message


 async def add_chat_messages_batch(
     session_id: str,
     messages: list[dict[str, Any]],
     start_sequence: int,
-) -> list[ChatMessage]:
+) -> list[PrismaChatMessage]:
     """Add multiple messages to a chat session in a batch.

     Uses a transaction for atomicity - if any message creation fails,
@@ -143,7 +147,7 @@ async def add_chat_messages_batch(

     created_messages = []

-    async with db.transaction() as tx:
+    async with transaction() as tx:
         for i, msg in enumerate(messages):
             # Build input dict dynamically rather than using ChatMessageCreateInput
             # directly because Prisma's TypedDict validation rejects optional fields
@@ -183,22 +187,21 @@ async def add_chat_messages_batch(
             data={"updatedAt": datetime.now(UTC)},
         )

-    return [ChatMessage.from_db(m) for m in created_messages]
+    return created_messages


 async def get_user_chat_sessions(
     user_id: str,
     limit: int = 50,
     offset: int = 0,
-) -> list[ChatSessionInfo]:
+) -> list[PrismaChatSession]:
     """Get chat sessions for a user, ordered by most recent."""
-    prisma_sessions = await PrismaChatSession.prisma().find_many(
+    return await PrismaChatSession.prisma().find_many(
         where={"userId": user_id},
         order={"updatedAt": "desc"},
         take=limit,
         skip=offset,
     )
-    return [ChatSessionInfo.from_db(s) for s in prisma_sessions]


 async def get_user_session_count(user_id: str) -> int:
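The batch helper above expects plain dicts whose keys mirror the ChatMessage fields, plus a start_sequence that continues from the rows already persisted. A hedged usage sketch (the session id and content are made up):

existing = await get_chat_session_message_count("session-123")  # count helper referenced elsewhere in this module
created = await add_chat_messages_batch(
    session_id="session-123",
    messages=[
        {
            "role": "user",
            "content": "Hello",
            "name": None,
            "tool_call_id": None,
            "refusal": None,
            "tool_calls": None,
            "function_call": None,
        }
    ],
    start_sequence=existing,
)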
@@ -2,7 +2,7 @@ import asyncio
 import logging
 import uuid
 from datetime import UTC, datetime
-from typing import Any, Self, cast
+from typing import Any, cast
 from weakref import WeakValueDictionary

 from openai.types.chat import (
@@ -23,17 +23,26 @@ from prisma.models import ChatMessage as PrismaChatMessage
 from prisma.models import ChatSession as PrismaChatSession
 from pydantic import BaseModel

-from backend.data.db_accessors import chat_db
 from backend.data.redis_client import get_redis_async
 from backend.util import json
 from backend.util.exceptions import DatabaseError, RedisError

+from . import db as chat_db
 from .config import ChatConfig

 logger = logging.getLogger(__name__)
 config = ChatConfig()


+def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
+    """Parse a JSON field that may be stored as string or already parsed."""
+    if value is None:
+        return default
+    if isinstance(value, str):
+        return json.loads(value)
+    return value
+
+
 # Redis cache key prefix for chat sessions
 CHAT_SESSION_CACHE_PREFIX = "chat:session:"

@@ -43,7 +52,28 @@ def _get_session_cache_key(session_id: str) -> str:
     return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"


-# ===================== Chat data models ===================== #
+# Session-level locks to prevent race conditions during concurrent upserts.
+# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
+# preventing unbounded memory growth while maintaining lock semantics for active sessions.
+# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
+# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
+_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
+_session_locks_mutex = asyncio.Lock()
+
+
+async def _get_session_lock(session_id: str) -> asyncio.Lock:
+    """Get or create a lock for a specific session to prevent concurrent upserts.
+
+    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
+    when no coroutine holds a reference to them, preventing memory leaks from
+    unbounded growth of session locks.
+    """
+    async with _session_locks_mutex:
+        lock = _session_locks.get(session_id)
+        if lock is None:
+            lock = asyncio.Lock()
+            _session_locks[session_id] = lock
+        return lock
+
+
 class ChatMessage(BaseModel):
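The per-session lock is meant to be taken around any read-modify-write of a session, as upsert_chat_session does further down; a minimal sketch:

lock = await _get_session_lock(session.session_id)
async with lock:
    # read the current session, mutate it, then persist and re-cache it
    ...
# Once no coroutine still references `lock`, the WeakValueDictionary entry is
# garbage collected, so the registry does not grow with the number of sessions seen.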
@@ -55,19 +85,6 @@ class ChatMessage(BaseModel):
     tool_calls: list[dict] | None = None
     function_call: dict | None = None

-    @staticmethod
-    def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
-        """Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
-        return ChatMessage(
-            role=prisma_message.role,
-            content=prisma_message.content,
-            name=prisma_message.name,
-            tool_call_id=prisma_message.toolCallId,
-            refusal=prisma_message.refusal,
-            tool_calls=_parse_json_field(prisma_message.toolCalls),
-            function_call=_parse_json_field(prisma_message.functionCall),
-        )


 class Usage(BaseModel):
     prompt_tokens: int
@@ -75,10 +92,11 @@ class Usage(BaseModel):
     total_tokens: int


-class ChatSessionInfo(BaseModel):
+class ChatSession(BaseModel):
     session_id: str
     user_id: str
     title: str | None = None
+    messages: list[ChatMessage]
     usage: list[Usage]
     credentials: dict[str, dict] = {}  # Map of provider -> credential metadata
     started_at: datetime
@@ -86,9 +104,60 @@ class ChatSessionInfo(BaseModel):
     successful_agent_runs: dict[str, int] = {}
     successful_agent_schedules: dict[str, int] = {}

-    @classmethod
-    def from_db(cls, prisma_session: PrismaChatSession) -> Self:
-        """Convert Prisma ChatSession to Pydantic ChatSession."""
+    def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
+        """Attach a tool_call to the current turn's assistant message.
+
+        Searches backwards for the most recent assistant message (stopping at
+        any user message boundary). If found, appends the tool_call to it.
+        Otherwise creates a new assistant message with the tool_call.
+        """
+        for msg in reversed(self.messages):
+            if msg.role == "user":
+                break
+            if msg.role == "assistant":
+                if not msg.tool_calls:
+                    msg.tool_calls = []
+                msg.tool_calls.append(tool_call)
+                return
+
+        self.messages.append(
+            ChatMessage(role="assistant", content="", tool_calls=[tool_call])
+        )
+
+    @staticmethod
+    def new(user_id: str) -> "ChatSession":
+        return ChatSession(
+            session_id=str(uuid.uuid4()),
+            user_id=user_id,
+            title=None,
+            messages=[],
+            usage=[],
+            credentials={},
+            started_at=datetime.now(UTC),
+            updated_at=datetime.now(UTC),
+        )
+
+    @staticmethod
+    def from_db(
+        prisma_session: PrismaChatSession,
+        prisma_messages: list[PrismaChatMessage] | None = None,
+    ) -> "ChatSession":
+        """Convert Prisma models to Pydantic ChatSession."""
+        messages = []
+        if prisma_messages:
+            for msg in prisma_messages:
+                messages.append(
+                    ChatMessage(
+                        role=msg.role,
+                        content=msg.content,
+                        name=msg.name,
+                        tool_call_id=msg.toolCallId,
+                        refusal=msg.refusal,
+                        tool_calls=_parse_json_field(msg.toolCalls),
+                        function_call=_parse_json_field(msg.functionCall),
+                    )
+                )
+
         # Parse JSON fields from Prisma
         credentials = _parse_json_field(prisma_session.credentials, default={})
         successful_agent_runs = _parse_json_field(
@@ -110,10 +179,11 @@ class ChatSessionInfo(BaseModel):
             )
         )

-        return cls(
+        return ChatSession(
             session_id=prisma_session.id,
             user_id=prisma_session.userId,
             title=prisma_session.title,
+            messages=messages,
             usage=usage,
             credentials=credentials,
             started_at=prisma_session.createdAt,
@@ -122,55 +192,46 @@ class ChatSessionInfo(BaseModel):
             successful_agent_schedules=successful_agent_schedules,
         )

-
-class ChatSession(ChatSessionInfo):
-    messages: list[ChatMessage]
-
-    @classmethod
-    def new(cls, user_id: str) -> Self:
-        return cls(
-            session_id=str(uuid.uuid4()),
-            user_id=user_id,
-            title=None,
-            messages=[],
-            usage=[],
-            credentials={},
-            started_at=datetime.now(UTC),
-            updated_at=datetime.now(UTC),
-        )
-
-    @classmethod
-    def from_db(cls, prisma_session: PrismaChatSession) -> Self:
-        """Convert Prisma ChatSession to Pydantic ChatSession."""
-        if prisma_session.Messages is None:
-            raise ValueError(
-                f"Prisma session {prisma_session.id} is missing Messages relation"
-            )
-
-        return cls(
-            **ChatSessionInfo.from_db(prisma_session).model_dump(),
-            messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
-        )
-
-    def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
-        """Attach a tool_call to the current turn's assistant message.
-
-        Searches backwards for the most recent assistant message (stopping at
-        any user message boundary). If found, appends the tool_call to it.
-        Otherwise creates a new assistant message with the tool_call.
-        """
-        for msg in reversed(self.messages):
-            if msg.role == "user":
-                break
-            if msg.role == "assistant":
-                if not msg.tool_calls:
-                    msg.tool_calls = []
-                msg.tool_calls.append(tool_call)
-                return
-
-        self.messages.append(
-            ChatMessage(role="assistant", content="", tool_calls=[tool_call])
-        )
+    @staticmethod
+    def _merge_consecutive_assistant_messages(
+        messages: list[ChatCompletionMessageParam],
+    ) -> list[ChatCompletionMessageParam]:
+        """Merge consecutive assistant messages into single messages.
+
+        Long-running tool flows can create split assistant messages: one with
+        text content and another with tool_calls. Anthropic's API requires
+        tool_result blocks to reference a tool_use in the immediately preceding
+        assistant message, so these splits cause 400 errors via OpenRouter.
+        """
+        if len(messages) < 2:
+            return messages
+
+        result: list[ChatCompletionMessageParam] = [messages[0]]
+        for msg in messages[1:]:
+            prev = result[-1]
+            if prev.get("role") != "assistant" or msg.get("role") != "assistant":
+                result.append(msg)
+                continue
+
+            prev = cast(ChatCompletionAssistantMessageParam, prev)
+            curr = cast(ChatCompletionAssistantMessageParam, msg)
+
+            curr_content = curr.get("content") or ""
+            if curr_content:
+                prev_content = prev.get("content") or ""
+                prev["content"] = (
+                    f"{prev_content}\n{curr_content}" if prev_content else curr_content
+                )
+
+            curr_tool_calls = curr.get("tool_calls")
+            if curr_tool_calls:
+                prev_tool_calls = prev.get("tool_calls")
+                prev["tool_calls"] = (
+                    list(prev_tool_calls) + list(curr_tool_calls)
+                    if prev_tool_calls
+                    else list(curr_tool_calls)
+                )
+        return result

     def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
         messages = []
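Illustrative input and output for _merge_consecutive_assistant_messages; the message payloads are made up, but they follow the OpenAI chat-completion dict shape the method operates on:

split = [
    {"role": "assistant", "content": "Let me run that agent."},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [{"id": "call_1", "type": "function",
                        "function": {"name": "run_agent", "arguments": "{}"}}],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "done"},
]
merged = ChatSession._merge_consecutive_assistant_messages(split)
# merged[0] now carries both the text and the tool_calls; merged[1] is the tool result,
# so the tool_result immediately follows the assistant message that declared the tool_use.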
@@ -260,70 +321,40 @@ class ChatSession(ChatSessionInfo):
             )
         return self._merge_consecutive_assistant_messages(messages)

-    @staticmethod
-    def _merge_consecutive_assistant_messages(
-        messages: list[ChatCompletionMessageParam],
-    ) -> list[ChatCompletionMessageParam]:
-        """Merge consecutive assistant messages into single messages.
-
-        Long-running tool flows can create split assistant messages: one with
-        text content and another with tool_calls. Anthropic's API requires
-        tool_result blocks to reference a tool_use in the immediately preceding
-        assistant message, so these splits cause 400 errors via OpenRouter.
-        """
-        if len(messages) < 2:
-            return messages
-
-        result: list[ChatCompletionMessageParam] = [messages[0]]
-        for msg in messages[1:]:
-            prev = result[-1]
-            if prev.get("role") != "assistant" or msg.get("role") != "assistant":
-                result.append(msg)
-                continue
-
-            prev = cast(ChatCompletionAssistantMessageParam, prev)
-            curr = cast(ChatCompletionAssistantMessageParam, msg)
-
-            curr_content = curr.get("content") or ""
-            if curr_content:
-                prev_content = prev.get("content") or ""
-                prev["content"] = (
-                    f"{prev_content}\n{curr_content}" if prev_content else curr_content
-                )
-
-            curr_tool_calls = curr.get("tool_calls")
-            if curr_tool_calls:
-                prev_tool_calls = prev.get("tool_calls")
-                prev["tool_calls"] = (
-                    list(prev_tool_calls) + list(curr_tool_calls)
-                    if prev_tool_calls
-                    else list(curr_tool_calls)
-                )
-        return result
-
-
-def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
-    """Parse a JSON field that may be stored as string or already parsed."""
-    if value is None:
-        return default
-    if isinstance(value, str):
-        return json.loads(value)
-    return value
-
-
-# ================ Chat cache + DB operations ================ #
-
-# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
-# connected directly.
-
-
-async def cache_chat_session(session: ChatSession) -> None:
-    """Cache a chat session in Redis (without persisting to the database)."""
+
+async def _get_session_from_cache(session_id: str) -> ChatSession | None:
+    """Get a chat session from Redis cache."""
+    redis_key = _get_session_cache_key(session_id)
+    async_redis = await get_redis_async()
+    raw_session: bytes | None = await async_redis.get(redis_key)
+
+    if raw_session is None:
+        return None
+
+    try:
+        session = ChatSession.model_validate_json(raw_session)
+        logger.info(
+            f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
+            f"last_roles={[m.role for m in session.messages[-3:]]}"  # Last 3 roles
+        )
+        return session
+    except Exception as e:
+        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
+        raise RedisError(f"Corrupted session data for {session_id}") from e
+
+
+async def _cache_session(session: ChatSession) -> None:
+    """Cache a chat session in Redis."""
     redis_key = _get_session_cache_key(session.session_id)
     async_redis = await get_redis_async()
     await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())


+async def cache_chat_session(session: ChatSession) -> None:
+    """Cache a chat session without persisting to the database."""
+    await _cache_session(session)
+
+
 async def invalidate_session_cache(session_id: str) -> None:
     """Invalidate a chat session from Redis cache.

@@ -339,6 +370,77 @@ async def invalidate_session_cache(session_id: str) -> None:
         logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")


+async def _get_session_from_db(session_id: str) -> ChatSession | None:
+    """Get a chat session from the database."""
+    prisma_session = await chat_db.get_chat_session(session_id)
+    if not prisma_session:
+        return None
+
+    messages = prisma_session.Messages
+    logger.debug(
+        f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
+        f"roles={[m.role for m in messages[-3:]] if messages else []}"  # Last 3 roles
+    )
+
+    return ChatSession.from_db(prisma_session, messages)
+
+
+async def _save_session_to_db(
+    session: ChatSession, existing_message_count: int
+) -> None:
+    """Save or update a chat session in the database."""
+    # Check if session exists in DB
+    existing = await chat_db.get_chat_session(session.session_id)
+
+    if not existing:
+        # Create new session
+        await chat_db.create_chat_session(
+            session_id=session.session_id,
+            user_id=session.user_id,
+        )
+        existing_message_count = 0
+
+    # Calculate total tokens from usage
+    total_prompt = sum(u.prompt_tokens for u in session.usage)
+    total_completion = sum(u.completion_tokens for u in session.usage)
+
+    # Update session metadata
+    await chat_db.update_chat_session(
+        session_id=session.session_id,
+        credentials=session.credentials,
+        successful_agent_runs=session.successful_agent_runs,
+        successful_agent_schedules=session.successful_agent_schedules,
+        total_prompt_tokens=total_prompt,
+        total_completion_tokens=total_completion,
+    )
+
+    # Add new messages (only those after existing count)
+    new_messages = session.messages[existing_message_count:]
+    if new_messages:
+        messages_data = []
+        for msg in new_messages:
+            messages_data.append(
+                {
+                    "role": msg.role,
+                    "content": msg.content,
+                    "name": msg.name,
+                    "tool_call_id": msg.tool_call_id,
+                    "refusal": msg.refusal,
+                    "tool_calls": msg.tool_calls,
+                    "function_call": msg.function_call,
+                }
+            )
+        logger.debug(
+            f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
+            f"roles={[m['role'] for m in messages_data]}"
+        )
+        await chat_db.add_chat_messages_batch(
+            session_id=session.session_id,
+            messages=messages_data,
+            start_sequence=existing_message_count,
+        )
+
+
 async def get_chat_session(
     session_id: str,
     user_id: str | None = None,
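The incremental contract of _save_session_to_db: only messages past existing_message_count are inserted, so callers first ask the db layer how many rows already exist (as upsert_chat_session does below):

count = await chat_db.get_chat_session_message_count(session.session_id)
await _save_session_to_db(session, existing_message_count=count)
# If the session row did not exist yet, the helper creates it and resets the count to 0,
# so every in-memory message is written on the first save.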
@@ -386,53 +488,16 @@ async def get_chat_session(

     # Cache the session from DB
     try:
-        await cache_chat_session(session)
-        logger.info(f"Cached session {session_id} from database")
+        await _cache_session(session)
     except Exception as e:
         logger.warning(f"Failed to cache session {session_id}: {e}")

     return session


-async def _get_session_from_cache(session_id: str) -> ChatSession | None:
-    """Get a chat session from Redis cache."""
-    redis_key = _get_session_cache_key(session_id)
-    async_redis = await get_redis_async()
-    raw_session: bytes | None = await async_redis.get(redis_key)
-
-    if raw_session is None:
-        return None
-
-    try:
-        session = ChatSession.model_validate_json(raw_session)
-        logger.info(
-            f"Loading session {session_id} from cache: "
-            f"message_count={len(session.messages)}, "
-            f"roles={[m.role for m in session.messages]}"
-        )
-        return session
-    except Exception as e:
-        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
-        raise RedisError(f"Corrupted session data for {session_id}") from e
-
-
-async def _get_session_from_db(session_id: str) -> ChatSession | None:
-    """Get a chat session from the database."""
-    session = await chat_db().get_chat_session(session_id)
-    if not session:
-        return None
-
-    logger.info(
-        f"Loaded session {session_id} from DB: "
-        f"has_messages={bool(session.messages)}, "
-        f"message_count={len(session.messages)}, "
-        f"roles={[m.role for m in session.messages]}"
-    )
-
-    return session
-
-
-async def upsert_chat_session(session: ChatSession) -> ChatSession:
+async def upsert_chat_session(
+    session: ChatSession,
+) -> ChatSession:
     """Update a chat session in both cache and database.

     Uses session-level locking to prevent race conditions when concurrent
@@ -450,7 +515,7 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:

     async with lock:
         # Get existing message count from DB for incremental saves
-        existing_message_count = await chat_db().get_chat_session_message_count(
+        existing_message_count = await chat_db.get_chat_session_message_count(
             session.session_id
         )

@@ -467,7 +532,7 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:

         # Save to cache (best-effort, even if DB failed)
         try:
-            await cache_chat_session(session)
+            await _cache_session(session)
         except Exception as e:
             # If DB succeeded but cache failed, raise cache error
             if db_error is None:
@@ -488,65 +553,6 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
     return session


-async def _save_session_to_db(
-    session: ChatSession, existing_message_count: int
-) -> None:
-    """Save or update a chat session in the database."""
-    db = chat_db()
-
-    # Check if session exists in DB
-    existing = await db.get_chat_session(session.session_id)
-
-    if not existing:
-        # Create new session
-        await db.create_chat_session(
-            session_id=session.session_id,
-            user_id=session.user_id,
-        )
-        existing_message_count = 0
-
-    # Calculate total tokens from usage
-    total_prompt = sum(u.prompt_tokens for u in session.usage)
-    total_completion = sum(u.completion_tokens for u in session.usage)
-
-    # Update session metadata
-    await db.update_chat_session(
-        session_id=session.session_id,
-        credentials=session.credentials,
-        successful_agent_runs=session.successful_agent_runs,
-        successful_agent_schedules=session.successful_agent_schedules,
-        total_prompt_tokens=total_prompt,
-        total_completion_tokens=total_completion,
-    )
-
-    # Add new messages (only those after existing count)
-    new_messages = session.messages[existing_message_count:]
-    if new_messages:
-        messages_data = []
-        for msg in new_messages:
-            messages_data.append(
-                {
-                    "role": msg.role,
-                    "content": msg.content,
-                    "name": msg.name,
-                    "tool_call_id": msg.tool_call_id,
-                    "refusal": msg.refusal,
-                    "tool_calls": msg.tool_calls,
-                    "function_call": msg.function_call,
-                }
-            )
-        logger.info(
-            f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
-            f"roles={[m['role'] for m in messages_data]}, "
-            f"start_sequence={existing_message_count}"
-        )
-        await db.add_chat_messages_batch(
-            session_id=session.session_id,
-            messages=messages_data,
-            start_sequence=existing_message_count,
-        )
-
-
 async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
     """Atomically append a message to a session and persist it.
@@ -562,7 +568,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
|||||||
raise ValueError(f"Session {session_id} not found")
|
raise ValueError(f"Session {session_id} not found")
|
||||||
|
|
||||||
session.messages.append(message)
|
session.messages.append(message)
|
||||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||||
session_id
|
session_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -574,7 +580,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await cache_chat_session(session)
|
await _cache_session(session)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
||||||
|
|
||||||
@@ -593,7 +599,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
|||||||
|
|
||||||
# Create in database first - fail fast if this fails
|
# Create in database first - fail fast if this fails
|
||||||
try:
|
try:
|
||||||
await chat_db().create_chat_session(
|
await chat_db.create_chat_session(
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
@@ -605,7 +611,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
|||||||
|
|
||||||
# Cache the session (best-effort optimization, DB is source of truth)
|
# Cache the session (best-effort optimization, DB is source of truth)
|
||||||
try:
|
try:
|
||||||
await cache_chat_session(session)
|
await _cache_session(session)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
|
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
|
||||||
|
|
||||||
@@ -616,16 +622,20 @@ async def get_user_sessions(
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
limit: int = 50,
|
limit: int = 50,
offset: int = 0,
) -> tuple[list[ChatSessionInfo], int]:
) -> tuple[list[ChatSession], int]:
"""Get chat sessions for a user from the database with total count.

Returns:
A tuple of (sessions, total_count) where total_count is the overall
number of sessions for the user (not just the current page).
"""
db = chat_db()
sessions = await db.get_user_chat_sessions(user_id, limit, offset)
total_count = await db.get_user_session_count(user_id)
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
total_count = await chat_db.get_user_session_count(user_id)
sessions = []
for prisma_session in prisma_sessions:
# Convert without messages for listing (lighter weight)
sessions.append(ChatSession.from_db(prisma_session, None))

return sessions, total_count

@@ -643,7 +653,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
"""
# Delete from database first (with optional user_id validation)
# This confirms ownership before invalidating cache
deleted = await chat_db().delete_chat_session(session_id, user_id)
deleted = await chat_db.delete_chat_session(session_id, user_id)

if not deleted:
return False
@@ -678,7 +688,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
True if updated successfully, False otherwise.
"""
try:
result = await chat_db().update_chat_session(session_id=session_id, title=title)
result = await chat_db.update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
return False
@@ -690,7 +700,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await cache_chat_session(cached)
await _cache_session(cached)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
@@ -701,29 +711,3 @@ async def update_session_title(session_id: str, title: str) -> bool:
except Exception as e:
logger.error(f"Failed to update title for session {session_id}: {e}")
return False


# ==================== Chat session locks ==================== #

_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()


async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.

This was originally added to solve the specific problem of race conditions between
the session title thread and the conversation thread, which always occurs on the
same instance as we prevent rapid request sends on the frontend.

Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
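The helper above pairs a WeakValueDictionary with a guarding mutex so that each session id maps to at most one asyncio.Lock. A minimal usage sketch, assuming a persistence helper such as save_chat_session and a session_id attribute (neither is shown in this diff):

async def upsert_session_safely(session) -> None:
    # Only one coroutine per session id enters this block at a time.
    lock = await _get_session_lock(session.session_id)
    async with lock:
        await save_chat_session(session)  # assumed helper, not part of the diff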
@@ -11,25 +11,24 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response,
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.completion_handler import (
process_operation_failure,
process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_copilot_task
from backend.copilot.model import (
from backend.util.exceptions import NotFoundError
from backend.util.feature_flag import Flag, is_feature_enabled
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
delete_chat_session,
get_chat_session,
get_user_sessions,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.models import (
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
from .sdk import service as sdk_service
from .tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
AgentPreviewResponse,
@@ -52,8 +51,7 @@ from backend.copilot.tools.models import (
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from .tracking import track_user_message
from backend.util.exceptions import NotFoundError

config = ChatConfig()

@@ -213,43 +211,6 @@ async def create_session(
)


@router.delete(
"/sessions/{session_id}",
dependencies=[Security(auth.requires_user)],
status_code=204,
responses={404: {"description": "Session not found or access denied"}},
)
async def delete_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
"""
Delete a chat session.

Permanently removes a chat session and all its messages.
Only the owner can delete their sessions.

Args:
session_id: The session ID to delete.
user_id: The authenticated user's ID.

Returns:
204 No Content on success.

Raises:
HTTPException: 404 if session not found or not owned by user.
"""
deleted = await delete_chat_session(session_id, user_id)

if not deleted:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)

return Response(status_code=204)


@router.get(
"/sessions/{session_id}",
)
@@ -355,7 +316,7 @@ async def stream_chat_post(
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
await _validate_and_get_session(session_id, user_id)
session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
@@ -382,7 +343,7 @@ async def stream_chat_post(
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
session = await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")

# Create a task in the stream registry for reconnection support
@@ -409,19 +370,125 @@ async def stream_chat_post(
},
)

await enqueue_copilot_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
operation_id=operation_id,
message=request.message,
is_user_message=request.is_user_message,
context=request.context,
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
import time as time_module

gen_start_time = time_module.perf_counter()
logger.info(
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
* 1000,
}
},
)

# Choose service based on LaunchDarkly flag (falls back to config default)
use_sdk = await is_feature_enabled(
Flag.COPILOT_SDK,
user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
logger.info(
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
extra={"json_fields": log_meta},
)
# Pass message=None since we already added it to the session above
async for chunk in stream_fn(
session_id,
None, # Message already in session
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass session with message already added
context=request.context,
):
# Skip duplicate StreamStart — we already published one above
if isinstance(chunk, StreamStart):
continue
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
ttfc = first_chunk_time - gen_start_time
logger.info(
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"time_to_first_chunk_ms": ttfc * 1000,
}
},
)
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)

gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"time_to_first_chunk_ms": (
ttfc * 1000 if ttfc is not None else None
),
"n_chunks": chunk_count,
}
},
)
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
elapsed = time_module.perf_counter() - gen_start_time
logger.error(
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed * 1000,
"error": str(e),
}
},
)
# Publish a StreamError so the frontend can display an error message
try:
await stream_registry.publish_chunk(
task_id,
StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
),
)
except Exception:
pass # Best-effort; mark_task_completed will publish StreamFinish
await stream_registry.mark_task_completed(task_id, "failed")

# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)

@@ -20,7 +20,7 @@ from claude_agent_sdk import (
UserMessage,
)

from backend.copilot.response_model import (
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
@@ -34,8 +34,10 @@ from backend.copilot.response_model import (
StreamToolInputStart,
StreamToolOutputAvailable,
)
from .tool_adapter import MCP_TOOL_PREFIX, pop_pending_tool_output
from backend.api.features.chat.sdk.tool_adapter import (
MCP_TOOL_PREFIX,
pop_pending_tool_output,
)

logger = logging.getLogger(__name__)

@@ -10,7 +10,7 @@ from claude_agent_sdk import (
UserMessage,
)

from backend.copilot.response_model import (
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
@@ -11,7 +11,7 @@ import re
from collections.abc import Callable
from typing import Any, cast

from .tool_adapter import (
from backend.api.features.chat.sdk.tool_adapter import (
BLOCKED_TOOLS,
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
@@ -1,9 +1,4 @@
"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.

These are pure unit tests with no external dependencies (no SDK, no DB, no server).
They validate that the security hooks correctly block unauthorized paths,
tool access, and dangerous input patterns.
"""
"""Unit tests for SDK security hooks."""

import os

@@ -17,10 +12,6 @@ def _is_denied(result: dict) -> bool:
return hook.get("permissionDecision") == "deny"


def _reason(result: dict) -> str:
return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")


# -- Blocked tools -----------------------------------------------------------


@@ -172,19 +163,3 @@ def test_non_workspace_tool_passes_isolation():
"find_agent", {"query": "email"}, user_id="user-1"
)
assert result == {}


# -- Deny message quality ----------------------------------------------------


def test_blocked_tool_message_clarity():
"""Deny messages must include [SECURITY] and 'cannot be bypassed'."""
reason = _reason(_validate_tool_access("bash", {}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason


def test_bash_builtin_blocked_message_clarity():
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason
@@ -255,7 +255,7 @@ def _build_sdk_env() -> dict[str, str]:
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.

Delegates to :func:`~backend.copilot.tools.sandbox.make_session_path`
Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
(single source of truth for path sanitization) and adds a defence-in-depth
assertion.
"""
@@ -440,16 +440,12 @@ async def stream_chat_completion_sdk(
f"Session {session_id} not found. Please create a new session first."
)

# Append the new message to the session if it's not already there
new_message_role = "user" if is_user_message else "assistant"
if message and (
len(session.messages) == 0
or not (
session.messages[-1].role == new_message_role
and session.messages[-1].content == message
)
):
session.messages.append(ChatMessage(role=new_message_role, content=message))
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
)
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)
@@ -693,15 +689,11 @@ async def stream_chat_completion_sdk(
await asyncio.sleep(0.5)
raw_transcript = read_transcript_file(captured_transcript.path)
if raw_transcript:
try:
async with asyncio.timeout(30):
await _upload_transcript_bg(
user_id, session_id, raw_transcript
)
except asyncio.TimeoutError:
logger.warning(
f"[SDK] Transcript upload timed out for {session_id}"
)
task = asyncio.create_task(
_upload_transcript_bg(user_id, session_id, raw_transcript)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
else:
logger.debug("[SDK] Stop hook fired but transcript not usable")

@@ -18,9 +18,9 @@ from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any

from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from backend.copilot.tools import TOOL_REGISTRY
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.api.features.chat.tools.base import BaseTool

logger = logging.getLogger(__name__)

@@ -27,18 +27,20 @@ from openai.types.chat import (
ChatCompletionToolParam,
)

from backend.data.db_accessors import chat_db, understanding_db
from backend.data.redis_client import get_redis_async
from backend.data.understanding import format_understanding_for_prompt
from backend.data.understanding import (
format_understanding_for_prompt,
get_business_understanding,
)
from backend.util.exceptions import NotFoundError
from backend.util.settings import AppEnvironment, Settings

from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
ChatMessage,
ChatSession,
ChatSessionInfo,
Usage,
cache_chat_session,
get_chat_session,
@@ -261,7 +263,7 @@ async def _build_system_prompt(
understanding = None
if user_id:
try:
understanding = await understanding_db().get_business_understanding(user_id)
understanding = await get_business_understanding(user_id)
except Exception as e:
logger.warning(f"Failed to fetch business understanding: {e}")
understanding = None
@@ -337,7 +339,7 @@ async def _generate_session_title(
async def assign_user_to_session(
session_id: str,
user_id: str,
) -> ChatSessionInfo:
) -> ChatSession:
"""
Assign a user to a chat session.
"""
@@ -426,16 +428,12 @@ async def stream_chat_completion(
f"Session {session_id} not found. Please create a new session first."
)

# Append the new message to the session if it's not already there
new_message_role = "user" if is_user_message else "assistant"
if message and (
len(session.messages) == 0
or not (
session.messages[-1].role == new_message_role
and session.messages[-1].content == message
)
):
session.messages.append(ChatMessage(role=new_message_role, content=message))
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
)
logger.info(
f"Appended message (role={'user' if is_user_message else 'assistant'}), "
f"new message_count={len(session.messages)}"
@@ -1772,7 +1770,7 @@ async def _update_pending_operation(
This is called by background tasks when long-running operations complete.
"""
# Update the message in database
updated = await chat_db().update_tool_message_content(
updated = await chat_db.update_tool_message_content(
session_id=session_id,
tool_call_id=tool_call_id,
new_content=result,
@@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any

from openai.types.chat import ChatCompletionToolParam

from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from backend.copilot.tracking import track_tool_called
from backend.api.features.chat.tracking import track_tool_called

from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
@@ -31,7 +31,7 @@ from .workspace_files import (
)

if TYPE_CHECKING:
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.api.features.chat.response_model import StreamToolOutputAvailable

logger = logging.getLogger(__name__)

@@ -6,11 +6,11 @@ import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr

from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials
@@ -3,9 +3,11 @@
import logging
from typing import Any

from backend.copilot.model import ChatSession
from backend.data.db_accessors import understanding_db
from backend.data.understanding import BusinessUnderstandingInput
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)

from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
@@ -97,9 +99,7 @@ and automations for the user's specific needs."""
]

# Upsert with merge
understanding = await understanding_db().upsert_business_understanding(
user_id, input_data
)
understanding = await upsert_business_understanding(user_id, input_data)

# Build current understanding summary (filter out empty values)
current_understanding = {
@@ -5,8 +5,9 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, NotRequired, TypedDict
|
from typing import Any, NotRequired, TypedDict
|
||||||
|
|
||||||
from backend.data.db_accessors import graph_db, library_db, store_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.data.graph import Graph, Link, Node
|
from backend.api.features.store import db as store_db
|
||||||
|
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .service import (
|
from .service import (
|
||||||
@@ -144,9 +145,8 @@ async def get_library_agent_by_id(
|
|||||||
Returns:
|
Returns:
|
||||||
LibraryAgentSummary if found, None otherwise
|
LibraryAgentSummary if found, None otherwise
|
||||||
"""
|
"""
|
||||||
db = library_db()
|
|
||||||
try:
|
try:
|
||||||
agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
|
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||||
if agent:
|
if agent:
|
||||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||||
return LibraryAgentSummary(
|
return LibraryAgentSummary(
|
||||||
@@ -163,7 +163,7 @@ async def get_library_agent_by_id(
|
|||||||
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent = await db.get_library_agent(agent_id, user_id)
|
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
if agent:
|
if agent:
|
||||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||||
return LibraryAgentSummary(
|
return LibraryAgentSummary(
|
||||||
@@ -215,7 +215,7 @@ async def get_library_agents_for_generation(
|
|||||||
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = await library_db().list_library_agents(
|
response = await library_db.list_library_agents(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_term=search_query,
|
search_term=search_query,
|
||||||
page=1,
|
page=1,
|
||||||
@@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation(
|
|||||||
List of LibraryAgentSummary with full input/output schemas
|
List of LibraryAgentSummary with full input/output schemas
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = await store_db().get_store_agents(
|
response = await store_db.get_store_agents(
|
||||||
search_query=search_query,
|
search_query=search_query,
|
||||||
page=1,
|
page=1,
|
||||||
page_size=max_results,
|
page_size=max_results,
|
||||||
@@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
|
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
|
||||||
graphs = await graph_db().get_store_listed_graphs(graph_ids)
|
graphs = await get_store_listed_graphs(*graph_ids)
|
||||||
|
|
||||||
results: list[LibraryAgentSummary] = []
|
results: list[LibraryAgentSummary] = []
|
||||||
for agent in agents_with_graphs:
|
for agent in agents_with_graphs:
|
||||||
@@ -673,10 +673,9 @@ async def save_agent_to_library(
|
|||||||
Tuple of (created Graph, LibraryAgent)
|
Tuple of (created Graph, LibraryAgent)
|
||||||
"""
|
"""
|
||||||
graph = json_to_graph(agent_json)
|
graph = json_to_graph(agent_json)
|
||||||
db = library_db()
|
|
||||||
if is_update:
|
if is_update:
|
||||||
return await db.update_graph_in_library(graph, user_id)
|
return await library_db.update_graph_in_library(graph, user_id)
|
||||||
return await db.create_graph_in_library(graph, user_id)
|
return await library_db.create_graph_in_library(graph, user_id)
|
||||||
|
|
||||||
|
|
||||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||||
@@ -736,14 +735,12 @@ async def get_agent_as_json(
|
|||||||
Returns:
|
Returns:
|
||||||
Agent as JSON dict or None if not found
|
Agent as JSON dict or None if not found
|
||||||
"""
|
"""
|
||||||
db = graph_db()
|
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||||
|
|
||||||
graph = await db.get_graph(agent_id, version=None, user_id=user_id)
|
|
||||||
|
|
||||||
if not graph and user_id:
|
if not graph and user_id:
|
||||||
try:
|
try:
|
||||||
library_agent = await library_db().get_library_agent(agent_id, user_id)
|
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
graph = await db.get_graph(
|
graph = await get_graph(
|
||||||
library_agent.graph_id, version=None, user_id=user_id
|
library_agent.graph_id, version=None, user_id=user_id
|
||||||
)
|
)
|
||||||
except NotFoundError:
|
except NotFoundError:
|
||||||
@@ -7,9 +7,10 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.library.model import LibraryAgent
|
from backend.api.features.library.model import LibraryAgent
|
||||||
from backend.copilot.model import ChatSession
|
from backend.data import execution as execution_db
|
||||||
from backend.data.db_accessors import execution_db, library_db
|
|
||||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
@@ -164,12 +165,10 @@ class AgentOutputTool(BaseTool):
|
|||||||
Resolve agent from provided identifiers.
|
Resolve agent from provided identifiers.
|
||||||
Returns (library_agent, error_message).
|
Returns (library_agent, error_message).
|
||||||
"""
|
"""
|
||||||
lib_db = library_db()
|
|
||||||
|
|
||||||
# Priority 1: Exact library agent ID
|
# Priority 1: Exact library agent ID
|
||||||
if library_agent_id:
|
if library_agent_id:
|
||||||
try:
|
try:
|
||||||
agent = await lib_db.get_library_agent(library_agent_id, user_id)
|
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||||
return agent, None
|
return agent, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||||
@@ -183,7 +182,7 @@ class AgentOutputTool(BaseTool):
|
|||||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||||
|
|
||||||
# Find in user's library by graph_id
|
# Find in user's library by graph_id
|
||||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
|
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||||
if not agent:
|
if not agent:
|
||||||
return (
|
return (
|
||||||
None,
|
None,
|
||||||
@@ -195,7 +194,7 @@ class AgentOutputTool(BaseTool):
|
|||||||
# Priority 3: Fuzzy name search in library
|
# Priority 3: Fuzzy name search in library
|
||||||
if agent_name:
|
if agent_name:
|
||||||
try:
|
try:
|
||||||
response = await lib_db.list_library_agents(
|
response = await library_db.list_library_agents(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_term=agent_name,
|
search_term=agent_name,
|
||||||
page_size=5,
|
page_size=5,
|
||||||
@@ -229,11 +228,9 @@ class AgentOutputTool(BaseTool):
|
|||||||
Fetch execution(s) based on filters.
|
Fetch execution(s) based on filters.
|
||||||
Returns (single_execution, available_executions_meta, error_message).
|
Returns (single_execution, available_executions_meta, error_message).
|
||||||
"""
|
"""
|
||||||
exec_db = execution_db()
|
|
||||||
|
|
||||||
# If specific execution_id provided, fetch it directly
|
# If specific execution_id provided, fetch it directly
|
||||||
if execution_id:
|
if execution_id:
|
||||||
execution = await exec_db.get_graph_execution(
|
execution = await execution_db.get_graph_execution(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
include_node_executions=False,
|
include_node_executions=False,
|
||||||
@@ -243,7 +240,7 @@ class AgentOutputTool(BaseTool):
|
|||||||
return execution, [], None
|
return execution, [], None
|
||||||
|
|
||||||
# Get completed executions with time filters
|
# Get completed executions with time filters
|
||||||
executions = await exec_db.get_graph_executions(
|
executions = await execution_db.get_graph_executions(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
statuses=[ExecutionStatus.COMPLETED],
|
statuses=[ExecutionStatus.COMPLETED],
|
||||||
@@ -257,7 +254,7 @@ class AgentOutputTool(BaseTool):
|
|||||||
|
|
||||||
# If only one execution, fetch full details
|
# If only one execution, fetch full details
|
||||||
if len(executions) == 1:
|
if len(executions) == 1:
|
||||||
full_execution = await exec_db.get_graph_execution(
|
full_execution = await execution_db.get_graph_execution(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
execution_id=executions[0].id,
|
execution_id=executions[0].id,
|
||||||
include_node_executions=False,
|
include_node_executions=False,
|
||||||
@@ -265,7 +262,7 @@ class AgentOutputTool(BaseTool):
|
|||||||
return full_execution, [], None
|
return full_execution, [], None
|
||||||
|
|
||||||
# Multiple executions - return latest with full details, plus list of available
|
# Multiple executions - return latest with full details, plus list of available
|
||||||
full_execution = await exec_db.get_graph_execution(
|
full_execution = await execution_db.get_graph_execution(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
execution_id=executions[0].id,
|
execution_id=executions[0].id,
|
||||||
include_node_executions=False,
|
include_node_executions=False,
|
||||||
@@ -383,7 +380,7 @@ class AgentOutputTool(BaseTool):
|
|||||||
and not input_data.store_slug
|
and not input_data.store_slug
|
||||||
):
|
):
|
||||||
# Fetch execution directly to get graph_id
|
# Fetch execution directly to get graph_id
|
||||||
execution = await execution_db().get_graph_execution(
|
execution = await execution_db.get_graph_execution(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
execution_id=input_data.execution_id,
|
execution_id=input_data.execution_id,
|
||||||
include_node_executions=False,
|
include_node_executions=False,
|
||||||
@@ -395,7 +392,7 @@ class AgentOutputTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Find library agent by graph_id
|
# Find library agent by graph_id
|
||||||
agent = await library_db().get_library_agent_by_graph_id(
|
agent = await library_db.get_library_agent_by_graph_id(
|
||||||
user_id, execution.graph_id
|
user_id, execution.graph_id
|
||||||
)
|
)
|
||||||
if not agent:
|
if not agent:
|
||||||
@@ -4,7 +4,8 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from backend.data.db_accessors import library_db, store_db
|
from backend.api.features.library import db as library_db
|
||||||
|
from backend.api.features.store import db as store_db
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
@@ -44,10 +45,8 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
|||||||
Returns:
|
Returns:
|
||||||
AgentInfo if found, None otherwise
|
AgentInfo if found, None otherwise
|
||||||
"""
|
"""
|
||||||
lib_db = library_db()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
|
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||||
if agent:
|
if agent:
|
||||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||||
return AgentInfo(
|
return AgentInfo(
|
||||||
@@ -72,7 +71,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent = await lib_db.get_library_agent(agent_id, user_id)
|
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
if agent:
|
if agent:
|
||||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||||
return AgentInfo(
|
return AgentInfo(
|
||||||
@@ -134,7 +133,7 @@ async def search_agents(
|
|||||||
try:
|
try:
|
||||||
if source == "marketplace":
|
if source == "marketplace":
|
||||||
logger.info(f"Searching marketplace for: {query}")
|
logger.info(f"Searching marketplace for: {query}")
|
||||||
results = await store_db().get_store_agents(search_query=query, page_size=5)
|
results = await store_db.get_store_agents(search_query=query, page_size=5)
|
||||||
for agent in results.agents:
|
for agent in results.agents:
|
||||||
agents.append(
|
agents.append(
|
||||||
AgentInfo(
|
AgentInfo(
|
||||||
@@ -160,7 +159,7 @@ async def search_agents(
|
|||||||
|
|
||||||
if not agents:
|
if not agents:
|
||||||
logger.info(f"Searching user library for: {query}")
|
logger.info(f"Searching user library for: {query}")
|
||||||
results = await library_db().list_library_agents(
|
results = await library_db.list_library_agents(
|
||||||
user_id=user_id, # type: ignore[arg-type]
|
user_id=user_id, # type: ignore[arg-type]
|
||||||
search_term=query,
|
search_term=query,
|
||||||
page_size=10,
|
page_size=10,
|
||||||
@@ -5,8 +5,8 @@ from typing import Any
|
|||||||
|
|
||||||
from openai.types.chat import ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
|
||||||
from backend.copilot.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||||
|
|
||||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||||
|
|
||||||
@@ -11,11 +11,18 @@ available (e.g. macOS development).
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.copilot.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
from .base import BaseTool
|
from backend.api.features.chat.tools.models import (
|
||||||
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
|
BashExecResponse,
|
||||||
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
|
ErrorResponse,
|
||||||
|
ToolResponseBase,
|
||||||
|
)
|
||||||
|
from backend.api.features.chat.tools.sandbox import (
|
||||||
|
get_workspace_dir,
|
||||||
|
has_full_sandbox,
|
||||||
|
run_sandboxed,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -3,10 +3,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.copilot.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
from .base import BaseTool
|
from backend.api.features.chat.tools.models import (
|
||||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
ErrorResponse,
|
||||||
|
ResponseType,
|
||||||
|
ToolResponseBase,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -75,7 +78,7 @@ class CheckOperationStatusTool(BaseTool):
|
|||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
from backend.copilot import stream_registry
|
from backend.api.features.chat import stream_registry
|
||||||
|
|
||||||
operation_id = (kwargs.get("operation_id") or "").strip()
|
operation_id = (kwargs.get("operation_id") or "").strip()
|
||||||
task_id = (kwargs.get("task_id") or "").strip()
|
task_id = (kwargs.get("task_id") or "").strip()
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.copilot.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
AgentGeneratorNotConfiguredError,
|
AgentGeneratorNotConfiguredError,
|
||||||
@@ -3,9 +3,9 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.store import db as store_db
|
||||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||||
from backend.copilot.model import ChatSession
|
|
||||||
from backend.data.db_accessors import store_db as get_store_db
|
|
||||||
|
|
||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
AgentGeneratorNotConfiguredError,
|
AgentGeneratorNotConfiguredError,
|
||||||
@@ -137,8 +137,6 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
|
|
||||||
creator_username, agent_slug = parts
|
creator_username, agent_slug = parts
|
||||||
|
|
||||||
store_db = get_store_db()
|
|
||||||
|
|
||||||
# Fetch the marketplace agent details
|
# Fetch the marketplace agent details
|
||||||
try:
|
try:
|
||||||
agent_details = await store_db.get_store_agent_details(
|
agent_details = await store_db.get_store_agent_details(
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.copilot.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
AgentGeneratorNotConfiguredError,
|
AgentGeneratorNotConfiguredError,
|
||||||
@@ -5,14 +5,9 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.blocks.linear._api import LinearClient
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.copilot.model import ChatSession
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
from backend.data.db_accessors import user_db
|
from backend.api.features.chat.tools.models import (
|
||||||
from backend.data.model import APIKeyCredentials
|
|
||||||
from backend.util.settings import Settings
|
|
||||||
|
|
||||||
from .base import BaseTool
|
|
||||||
from .models import (
|
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
FeatureRequestCreatedResponse,
|
FeatureRequestCreatedResponse,
|
||||||
FeatureRequestInfo,
|
FeatureRequestInfo,
|
||||||
@@ -20,6 +15,10 @@ from .models import (
|
|||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
)
|
)
|
||||||
|
from backend.blocks.linear._api import LinearClient
|
||||||
|
from backend.data.model import APIKeyCredentials
|
||||||
|
from backend.data.user import get_user_email_by_id
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -105,8 +104,8 @@ def _get_linear_config() -> tuple[LinearClient, str, str]:
|
|||||||
Raises RuntimeError if any required setting is missing.
|
Raises RuntimeError if any required setting is missing.
|
||||||
"""
|
"""
|
||||||
secrets = _get_settings().secrets
|
secrets = _get_settings().secrets
|
||||||
if not secrets.copilot_linear_api_key:
|
if not secrets.linear_api_key:
|
||||||
raise RuntimeError("COPILOT_LINEAR_API_KEY is not configured")
|
raise RuntimeError("LINEAR_API_KEY is not configured")
|
||||||
if not secrets.linear_feature_request_project_id:
|
if not secrets.linear_feature_request_project_id:
|
||||||
raise RuntimeError("LINEAR_FEATURE_REQUEST_PROJECT_ID is not configured")
|
raise RuntimeError("LINEAR_FEATURE_REQUEST_PROJECT_ID is not configured")
|
||||||
if not secrets.linear_feature_request_team_id:
|
if not secrets.linear_feature_request_team_id:
|
||||||
@@ -115,7 +114,7 @@ def _get_linear_config() -> tuple[LinearClient, str, str]:
|
|||||||
credentials = APIKeyCredentials(
|
credentials = APIKeyCredentials(
|
||||||
id="system-linear",
|
id="system-linear",
|
||||||
provider="linear",
|
provider="linear",
|
||||||
api_key=SecretStr(secrets.copilot_linear_api_key),
|
api_key=SecretStr(secrets.linear_api_key),
|
||||||
title="System Linear API Key",
|
title="System Linear API Key",
|
||||||
)
|
)
|
||||||
client = LinearClient(credentials=credentials)
|
client = LinearClient(credentials=credentials)
|
||||||
@@ -333,9 +332,7 @@ class CreateFeatureRequestTool(BaseTool):
|
|||||||
# Resolve a human-readable name (email) for the Linear customer record.
|
# Resolve a human-readable name (email) for the Linear customer record.
|
||||||
# Fall back to user_id if the lookup fails or returns None.
|
# Fall back to user_id if the lookup fails or returns None.
|
||||||
try:
|
try:
|
||||||
customer_display_name = (
|
customer_display_name = await get_user_email_by_id(user_id) or user_id
|
||||||
await user_db().get_user_email_by_id(user_id) or user_id
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
customer_display_name = user_id
|
customer_display_name = user_id
|
||||||
|
|
||||||
@@ -1,18 +1,22 @@
|
|||||||
"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool."""
|
"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ._test_data import make_session
|
from backend.api.features.chat.tools.feature_requests import (
|
||||||
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
|
CreateFeatureRequestTool,
|
||||||
from .models import (
|
SearchFeatureRequestsTool,
|
||||||
|
)
|
||||||
|
from backend.api.features.chat.tools.models import (
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
FeatureRequestCreatedResponse,
|
FeatureRequestCreatedResponse,
|
||||||
FeatureRequestSearchResponse,
|
FeatureRequestSearchResponse,
|
||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
|
||||||
_TEST_USER_ID = "test-user-feature-requests"
|
_TEST_USER_ID = "test-user-feature-requests"
|
||||||
_TEST_USER_EMAIL = "testuser@example.com"
|
_TEST_USER_EMAIL = "testuser@example.com"
|
||||||
|
|
||||||
@@ -35,7 +39,7 @@ def _mock_linear_config(*, query_return=None, mutate_return=None):
|
|||||||
client.mutate.return_value = mutate_return
|
client.mutate.return_value = mutate_return
|
||||||
return (
|
return (
|
||||||
patch(
|
patch(
|
||||||
"backend.copilot.tools.feature_requests._get_linear_config",
|
"backend.api.features.chat.tools.feature_requests._get_linear_config",
|
||||||
return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID),
|
return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID),
|
||||||
),
|
),
|
||||||
client,
|
client,
|
||||||
@@ -204,7 +208,7 @@ class TestSearchFeatureRequestsTool:
|
|||||||
async def test_linear_client_init_failure(self):
|
async def test_linear_client_init_failure(self):
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
with patch(
|
with patch(
|
||||||
"backend.copilot.tools.feature_requests._get_linear_config",
|
"backend.api.features.chat.tools.feature_requests._get_linear_config",
|
||||||
side_effect=RuntimeError("No API key"),
|
side_effect=RuntimeError("No API key"),
|
||||||
):
|
):
|
||||||
tool = SearchFeatureRequestsTool()
|
tool = SearchFeatureRequestsTool()
|
||||||
@@ -227,11 +231,10 @@ class TestCreateFeatureRequestTool:
|
|||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _patch_email_lookup(self):
|
def _patch_email_lookup(self):
|
||||||
mock_user_db = MagicMock()
|
|
||||||
mock_user_db.get_user_email_by_id = AsyncMock(return_value=_TEST_USER_EMAIL)
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.copilot.tools.feature_requests.user_db",
|
"backend.api.features.chat.tools.feature_requests.get_user_email_by_id",
|
||||||
return_value=mock_user_db,
|
new_callable=AsyncMock,
|
||||||
|
return_value=_TEST_USER_EMAIL,
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@@ -344,7 +347,7 @@ class TestCreateFeatureRequestTool:
|
|||||||
async def test_linear_client_init_failure(self):
|
async def test_linear_client_init_failure(self):
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
with patch(
|
with patch(
|
||||||
"backend.copilot.tools.feature_requests._get_linear_config",
|
"backend.api.features.chat.tools.feature_requests._get_linear_config",
|
||||||
side_effect=RuntimeError("No API key"),
|
side_effect=RuntimeError("No API key"),
|
||||||
):
|
):
|
||||||
tool = CreateFeatureRequestTool()
|
tool = CreateFeatureRequestTool()
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.copilot.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_search import search_agents
|
from .agent_search import search_agents
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
@@ -3,18 +3,17 @@ from typing import Any
|
|||||||
|
|
||||||
from prisma.enums import ContentType
|
from prisma.enums import ContentType
|
||||||
|
|
||||||
from backend.blocks import get_block
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.blocks._base import BlockType
|
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
||||||
from backend.copilot.model import ChatSession
|
from backend.api.features.chat.tools.models import (
|
||||||
from backend.data.db_accessors import search
|
|
||||||
|
|
||||||
from .base import BaseTool, ToolResponseBase
|
|
||||||
from .models import (
|
|
||||||
BlockInfoSummary,
|
BlockInfoSummary,
|
||||||
BlockListResponse,
|
BlockListResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
|
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||||
|
from backend.blocks import get_block
|
||||||
|
from backend.blocks._base import BlockType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -108,7 +107,7 @@ class FindBlockTool(BaseTool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Search for blocks using hybrid search
|
# Search for blocks using hybrid search
|
||||||
results, total = await search().unified_hybrid_search(
|
results, total = await unified_hybrid_search(
|
||||||
query=query,
|
query=query,
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=1,
|
page=1,
|
||||||
@@ -4,15 +4,15 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

-from backend.blocks._base import BlockType
+from backend.api.features.chat.tools.find_block import (

-from ._test_data import make_session
-from .find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
FindBlockTool,
)
-from .models import BlockListResponse
+from backend.api.features.chat.tools.models import BlockListResponse
+from backend.blocks._base import BlockType

+from ._test_data import make_session

_TEST_USER_ID = "test-user-find-block"

@@ -84,17 +84,13 @@ class TestFindBlockFiltering:
"standard-block-id": standard_block,
}.get(block_id)

-mock_search_db = MagicMock()
-mock_search_db.unified_hybrid_search = AsyncMock(
-return_value=(search_results, 2)
-)

with patch(
-"backend.copilot.tools.find_block.search",
+"backend.api.features.chat.tools.find_block.unified_hybrid_search",
-return_value=mock_search_db,
+new_callable=AsyncMock,
+return_value=(search_results, 2),
):
with patch(
-"backend.copilot.tools.find_block.get_block",
+"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
@@ -132,17 +128,13 @@ class TestFindBlockFiltering:
"normal-block-id": normal_block,
}.get(block_id)

-mock_search_db = MagicMock()
-mock_search_db.unified_hybrid_search = AsyncMock(
-return_value=(search_results, 2)
-)

with patch(
-"backend.copilot.tools.find_block.search",
+"backend.api.features.chat.tools.find_block.unified_hybrid_search",
-return_value=mock_search_db,
+new_callable=AsyncMock,
+return_value=(search_results, 2),
):
with patch(
-"backend.copilot.tools.find_block.get_block",
+"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
@@ -361,16 +353,12 @@ class TestFindBlockFiltering:
for d in block_defs
}

-mock_search_db = MagicMock()
-mock_search_db.unified_hybrid_search = AsyncMock(
-return_value=(search_results, len(search_results))
-)

with patch(
-"backend.copilot.tools.find_block.search",
+"backend.api.features.chat.tools.find_block.unified_hybrid_search",
-return_value=mock_search_db,
+new_callable=AsyncMock,
+return_value=(search_results, len(search_results)),
), patch(
-"backend.copilot.tools.find_block.get_block",
+"backend.api.features.chat.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
):
tool = FindBlockTool()
@@ -2,7 +2,7 @@

from typing import Any

-from backend.copilot.model import ChatSession
+from backend.api.features.chat.model import ChatSession

from .agent_search import search_agents
from .base import BaseTool
@@ -4,10 +4,13 @@ import logging
from pathlib import Path
from typing import Any

-from backend.copilot.model import ChatSession
+from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tools.base import BaseTool
-from .base import BaseTool
+from backend.api.features.chat.tools.models import (
-from .models import DocPageResponse, ErrorResponse, ToolResponseBase
+DocPageResponse,
+ErrorResponse,
+ToolResponseBase,
+)

logger = logging.getLogger(__name__)
@@ -5,12 +5,16 @@ from typing import Any

from pydantic import BaseModel, Field, field_validator

-from backend.copilot.config import ChatConfig
+from backend.api.features.chat.config import ChatConfig
-from backend.copilot.model import ChatSession
+from backend.api.features.chat.model import ChatSession
-from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
+from backend.api.features.chat.tracking import (
-from backend.data.db_accessors import graph_db, library_db, user_db
+track_agent_run_success,
+track_agent_scheduled,
+)
+from backend.api.features.library import db as library_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
+from backend.data.user import get_user_by_id
from backend.executor import utils as execution_utils
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, NotFoundError
@@ -196,7 +200,7 @@ class RunAgentTool(BaseTool):

# Priority: library_agent_id if provided
if has_library_id:
-library_agent = await library_db().get_library_agent(
+library_agent = await library_db.get_library_agent(
params.library_agent_id, user_id
)
if not library_agent:
@@ -205,7 +209,9 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
# Get the graph from the library agent
-graph = await graph_db().get_graph(
+from backend.data.graph import get_graph

+graph = await get_graph(
library_agent.graph_id,
library_agent.graph_version,
user_id=user_id,
@@ -516,7 +522,7 @@ class RunAgentTool(BaseTool):
library_agent = await get_or_create_library_agent(graph, user_id)

# Get user timezone
-user = await user_db().get_user_by_id(user_id)
+user = await get_user_by_id(user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)

# Create schedule
@@ -7,17 +7,20 @@ from typing import Any

from pydantic_core import PydanticUndefined

+from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tools.find_block import (
+COPILOT_EXCLUDED_BLOCK_IDS,
+COPILOT_EXCLUDED_BLOCK_TYPES,
+)
from backend.blocks import get_block
from backend.blocks._base import AnyBlockSchema
-from backend.copilot.model import ChatSession
-from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
+from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError

from .base import BaseTool
-from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
from .helpers import get_inputs_from_schema
from .models import (
BlockDetails,
@@ -273,7 +276,7 @@ class RunBlockTool(BaseTool):

try:
# Get or create user's workspace for CoPilot file operations
-workspace = await workspace_db().get_or_create_workspace(user_id)
+workspace = await get_or_create_workspace(user_id)

# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
@@ -4,16 +4,16 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

-from backend.blocks._base import BlockType
+from backend.api.features.chat.tools.models import (

-from ._test_data import make_session
-from .models import (
BlockDetailsResponse,
BlockOutputResponse,
ErrorResponse,
InputValidationErrorResponse,
)
-from .run_block import RunBlockTool
+from backend.api.features.chat.tools.run_block import RunBlockTool
+from backend.blocks._base import BlockType

+from ._test_data import make_session

_TEST_USER_ID = "test-user-run-block"

@@ -77,7 +77,7 @@ class TestRunBlockFiltering:
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)

with patch(
-"backend.copilot.tools.run_block.get_block",
+"backend.api.features.chat.tools.run_block.get_block",
return_value=input_block,
):
tool = RunBlockTool()
@@ -103,7 +103,7 @@ class TestRunBlockFiltering:
)

with patch(
-"backend.copilot.tools.run_block.get_block",
+"backend.api.features.chat.tools.run_block.get_block",
return_value=smart_block,
):
tool = RunBlockTool()
@@ -127,7 +127,7 @@ class TestRunBlockFiltering:
)

with patch(
-"backend.copilot.tools.run_block.get_block",
+"backend.api.features.chat.tools.run_block.get_block",
return_value=standard_block,
):
tool = RunBlockTool()
@@ -183,7 +183,7 @@ class TestRunBlockInputValidation:
)

with patch(
-"backend.copilot.tools.run_block.get_block",
+"backend.api.features.chat.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
@@ -222,7 +222,7 @@ class TestRunBlockInputValidation:
)

with patch(
-"backend.copilot.tools.run_block.get_block",
+"backend.api.features.chat.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
@@ -263,7 +263,7 @@ class TestRunBlockInputValidation:
)

with patch(
-"backend.copilot.tools.run_block.get_block",
+"backend.api.features.chat.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
@@ -302,19 +302,15 @@ class TestRunBlockInputValidation:

mock_block.execute = mock_execute

-mock_workspace_db = MagicMock()
-mock_workspace_db.get_or_create_workspace = AsyncMock(
-return_value=MagicMock(id="test-workspace-id")
-)

with (
patch(
-"backend.copilot.tools.run_block.get_block",
+"backend.api.features.chat.tools.run_block.get_block",
return_value=mock_block,
),
patch(
-"backend.copilot.tools.run_block.workspace_db",
+"backend.api.features.chat.tools.run_block.get_or_create_workspace",
-return_value=mock_workspace_db,
+new_callable=AsyncMock,
+return_value=MagicMock(id="test-workspace-id"),
),
):
tool = RunBlockTool()
@@ -348,7 +344,7 @@ class TestRunBlockInputValidation:
)

with patch(
-"backend.copilot.tools.run_block.get_block",
+"backend.api.features.chat.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
@@ -5,17 +5,16 @@ from typing import Any

from prisma.enums import ContentType

-from backend.copilot.model import ChatSession
+from backend.api.features.chat.model import ChatSession
-from backend.data.db_accessors import search
+from backend.api.features.chat.tools.base import BaseTool
+from backend.api.features.chat.tools.models import (
-from .base import BaseTool
-from .models import (
DocSearchResult,
DocSearchResultsResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
+from backend.api.features.store.hybrid_search import unified_hybrid_search

logger = logging.getLogger(__name__)

@@ -118,7 +117,7 @@ class SearchDocsTool(BaseTool):

try:
# Search using hybrid search for DOCUMENTATION content type only
-results, total = await search().unified_hybrid_search(
+results, total = await unified_hybrid_search(
query=query,
content_types=[ContentType.DOCUMENTATION],
page=1,
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

+from backend.api.features.chat.tools.models import BlockDetailsResponse
+from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.blocks._base import BlockType
from backend.data.model import CredentialsMetaInput
from backend.integrations.providers import ProviderName

from ._test_data import make_session
-from .models import BlockDetailsResponse
-from .run_block import RunBlockTool

_TEST_USER_ID = "test-user-run-block-details"

@@ -61,7 +61,7 @@ async def test_run_block_returns_details_when_no_input_provided():
)

with patch(
-"backend.copilot.tools.run_block.get_block",
+"backend.api.features.chat.tools.run_block.get_block",
return_value=http_block,
):
# Mock credentials check to return no missing credentials
@@ -120,7 +120,7 @@ async def test_run_block_returns_details_when_only_credentials_provided():
}

with patch(
-"backend.copilot.tools.run_block.get_block",
+"backend.api.features.chat.tools.run_block.get_block",
return_value=mock,
):
with patch.object(
@@ -3,8 +3,9 @@
import logging
from typing import Any

+from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
-from backend.data.db_accessors import library_db, store_db
+from backend.api.features.store import db as store_db
from backend.data.graph import GraphModel
from backend.data.model import (
Credentials,
@@ -38,14 +39,13 @@ async def fetch_graph_from_store_slug(
Raises:
DatabaseError: If there's a database error during lookup.
"""
-sdb = store_db()
try:
-store_agent = await sdb.get_store_agent_details(username, agent_name)
+store_agent = await store_db.get_store_agent_details(username, agent_name)
except NotFoundError:
return None, None

# Get the graph from store listing version
-graph = await sdb.get_available_graph(
+graph = await store_db.get_available_graph(
store_agent.store_listing_version_id, hide_nodes=False
)
return graph, store_agent
@@ -210,13 +210,13 @@ async def get_or_create_library_agent(
Returns:
LibraryAgent instance
"""
-existing = await library_db().get_library_agent_by_graph_id(
+existing = await library_db.get_library_agent_by_graph_id(
graph_id=graph.id, user_id=user_id
)
if existing:
return existing

-library_agents = await library_db().create_library_agent(
+library_agents = await library_db.create_library_agent(
graph=graph,
user_id=user_id,
create_library_agents_for_sub_graphs=False,
@@ -6,12 +6,15 @@ from typing import Any
import aiohttp
import html2text

-from backend.copilot.model import ChatSession
+from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tools.base import BaseTool
+from backend.api.features.chat.tools.models import (
+ErrorResponse,
+ToolResponseBase,
+WebFetchResponse,
+)
from backend.util.request import Requests

-from .base import BaseTool
-from .models import ErrorResponse, ToolResponseBase, WebFetchResponse

logger = logging.getLogger(__name__)

# Limits
@@ -6,8 +6,8 @@ from typing import Any, Optional

from pydantic import BaseModel

-from backend.copilot.model import ChatSession
+from backend.api.features.chat.model import ChatSession
-from backend.data.db_accessors import workspace_db
+from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
@@ -148,7 +148,7 @@ class ListWorkspaceFilesTool(BaseTool):
include_all_sessions: bool = kwargs.get("include_all_sessions", False)

try:
-workspace = await workspace_db().get_or_create_workspace(user_id)
+workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -167,8 +167,8 @@ class ListWorkspaceFilesTool(BaseTool):
file_id=f.id,
name=f.name,
path=f.path,
-mime_type=f.mime_type,
+mime_type=f.mimeType,
-size_bytes=f.size_bytes,
+size_bytes=f.sizeBytes,
)
for f in files
]
@@ -284,7 +284,7 @@ class ReadWorkspaceFileTool(BaseTool):
)

try:
-workspace = await workspace_db().get_or_create_workspace(user_id)
+workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -309,8 +309,8 @@ class ReadWorkspaceFileTool(BaseTool):
target_file_id = file_info.id

# Decide whether to return inline content or metadata+URL
-is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
+is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
-is_text_file = self._is_text_mime_type(file_info.mime_type)
+is_text_file = self._is_text_mime_type(file_info.mimeType)

# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
@@ -321,7 +321,7 @@ class ReadWorkspaceFileTool(BaseTool):
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
-mime_type=file_info.mime_type,
+mime_type=file_info.mimeType,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
@@ -350,11 +350,11 @@ class ReadWorkspaceFileTool(BaseTool):
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
-mime_type=file_info.mime_type,
+mime_type=file_info.mimeType,
-size_bytes=file_info.size_bytes,
+size_bytes=file_info.sizeBytes,
download_url=download_url,
preview=preview,
-message=f"File: {file_info.name} ({file_info.size_bytes} bytes). Use download_url to retrieve content.",
+message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)

@@ -484,7 +484,7 @@ class WriteWorkspaceFileTool(BaseTool):
# Virus scan
await scan_content_safe(content, filename=filename)

-workspace = await workspace_db().get_or_create_workspace(user_id)
+workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -500,7 +500,7 @@ class WriteWorkspaceFileTool(BaseTool):
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
-size_bytes=file_record.size_bytes,
+size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)
@@ -583,7 +583,7 @@ class DeleteWorkspaceFileTool(BaseTool):
)

try:
-workspace = await workspace_db().get_or_create_workspace(user_id)
+workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
@@ -393,6 +393,7 @@ async def get_creators(
@router.get(
"/creator/{username}",
summary="Get creator details",
+operation_id="getV2GetCreatorDetails",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)
@@ -11,7 +11,7 @@ import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response

-from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
+from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage


@@ -44,11 +44,11 @@ router = fastapi.APIRouter(
)


-def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
+def _create_streaming_response(content: bytes, file) -> Response:
"""Create a streaming response for file content."""
return Response(
content=content,
-media_type=file.mime_type,
+media_type=file.mimeType,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)),
@@ -56,7 +56,7 @@ def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
)


-async def _create_file_download_response(file: WorkspaceFile) -> Response:
+async def _create_file_download_response(file) -> Response:
"""
Create a download response for a workspace file.

@@ -66,33 +66,33 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
storage = await get_workspace_storage()

# For local storage, stream the file directly
-if file.storage_path.startswith("local://"):
+if file.storagePath.startswith("local://"):
-content = await storage.retrieve(file.storage_path)
+content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)

# For GCS, try to redirect to signed URL, fall back to streaming
try:
-url = await storage.get_download_url(file.storage_path, expires_in=300)
+url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
-content = await storage.retrieve(file.storage_path)
+content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
logger.error(
f"Failed to get signed URL for file {file.id} "
-f"(storagePath={file.storage_path}): {e}",
+f"(storagePath={file.storagePath}): {e}",
exc_info=True,
)
# Fall back to streaming directly from GCS
try:
-content = await storage.retrieve(file.storage_path)
+content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
-f"(storagePath={file.storage_path}): {fallback_error}",
+f"(storagePath={file.storagePath}): {fallback_error}",
exc_info=True,
)
raise
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError

import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
+import backend.api.features.admin.llm_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
@@ -39,13 +40,15 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.integrations.webhooks.utils
+import backend.server.v2.llm.routes as public_llm_routes
import backend.util.service
import backend.util.settings
-from backend.blocks.llm import DEFAULT_LLM_MODEL
-from backend.copilot.completion_consumer import (
+from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
+from backend.data import llm_registry
+from backend.data.block_cost_config import refresh_llm_costs
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -116,11 +119,27 @@ async def lifespan_context(app: fastapi.FastAPI):

AutoRegistry.patch_integrations()

+# Refresh LLM registry before initializing blocks so blocks can use registry data
+await llm_registry.refresh_llm_registry()
+await refresh_llm_costs()
+
+# Clear block schema caches so they're regenerated with updated discriminator_mapping
+from backend.blocks._base import BlockSchema
+
+BlockSchema.clear_all_schema_caches()
+
await backend.data.block.initialize_blocks()

await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
-await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
+# migrate_llm_models uses registry default model
+from backend.blocks.llm import LlmModel
+
+default_model_slug = llm_registry.get_default_model_slug()
+if default_model_slug:
+await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
+else:
+logger.warning("Skipping LLM model migration: no default model available")
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()

# Start chat completion consumer for Redis Streams notifications
@@ -322,6 +341,16 @@ app.include_router(
tags=["v2", "executions", "review"],
prefix="/api/review",
)
+app.include_router(
+backend.api.features.admin.llm_routes.router,
+tags=["v2", "admin", "llm"],
+prefix="/api/llm/admin",
+)
+app.include_router(
+public_llm_routes.router,
+tags=["v2", "llm"],
+prefix="/api",
+)
app.include_router(
backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
)
@@ -79,11 +79,49 @@ async def event_broadcaster(manager: ConnectionManager):
payload=notification.payload,
)

-await asyncio.gather(execution_worker(), notification_worker())
+# Track registry pubsub for cleanup
+registry_pubsub = None
+
+async def registry_refresh_worker():
+"""Listen for LLM registry refresh notifications and broadcast to all clients."""
+nonlocal registry_pubsub
+from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
+from backend.data.redis_client import connect_async
+
+redis = await connect_async()
+registry_pubsub = redis.pubsub()
+await registry_pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
+logger.info(
+"Subscribed to LLM registry refresh notifications for WebSocket broadcast"
+)
+
+async for message in registry_pubsub.listen():
+if (
+message["type"] == "message"
+and message["channel"] == REGISTRY_REFRESH_CHANNEL
+):
+logger.info(
+"Broadcasting LLM registry refresh to all WebSocket clients"
+)
+await manager.broadcast_to_all(
+method=WSMethod.NOTIFICATION,
+data={
+"type": "LLM_REGISTRY_REFRESH",
+"event": "registry_updated",
+},
+)
+
+await asyncio.gather(
+execution_worker(),
+notification_worker(),
+registry_refresh_worker(),
+)
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()
+if registry_pubsub:
+await registry_pubsub.close()


async def authenticate_websocket(websocket: WebSocket) -> str:
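The hunk above only adds the subscriber side of the refresh channel. A hedged sketch of the matching publisher: the channel constant and `connect_async` helper are taken from the diff itself, but the idea that an admin route publishes after a registry update is an assumption, not code from this changeset.

```python
# Sketch only: illustrative publisher for the channel the WebSocket worker listens on.
from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
from backend.data.redis_client import connect_async


async def notify_registry_refresh() -> None:
    redis = await connect_async()
    # Any payload works; subscribers only check the channel and message type.
    await redis.publish(REGISTRY_REFRESH_CHANNEL, "registry_updated")
```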
@@ -38,9 +38,7 @@ def main(**kwargs):

from backend.api.rest_api import AgentServer
from backend.api.ws_api import WebsocketServer
-from backend.copilot.executor.manager import CoPilotExecutor
+from backend.executor import DatabaseManager, ExecutionManager, Scheduler
-from backend.data.db_manager import DatabaseManager
-from backend.executor import ExecutionManager, Scheduler
from backend.notifications import NotificationManager

run_processes(
@@ -50,7 +48,6 @@ def main(**kwargs):
WebsocketServer(),
AgentServer(),
ExecutionManager(),
-CoPilotExecutor(),
**kwargs,
)
@@ -134,7 +134,26 @@ class BlockInfo(BaseModel):


class BlockSchema(BaseModel):
-cached_jsonschema: ClassVar[dict[str, Any]]
+cached_jsonschema: ClassVar[dict[str, Any] | None] = None
+
+@classmethod
+def clear_schema_cache(cls) -> None:
+"""Clear the cached JSON schema for this class."""
+# Use None instead of {} because {} is truthy and would prevent regeneration
+cls.cached_jsonschema = None # type: ignore
+
+@staticmethod
+def clear_all_schema_caches() -> None:
+"""Clear cached JSON schemas for all BlockSchema subclasses."""
+
+def clear_recursive(cls: type) -> None:
+"""Recursively clear cache for class and all subclasses."""
+if hasattr(cls, "clear_schema_cache"):
+cls.clear_schema_cache()
+for subclass in cls.__subclasses__():
+clear_recursive(subclass)
+
+clear_recursive(BlockSchema)

@classmethod
def jsonschema(cls) -> dict[str, Any]:
@@ -225,7 +244,8 @@ class BlockSchema(BaseModel):
super().__pydantic_init_subclass__(**kwargs)

# Reset cached JSON schema to prevent inheriting it from parent class
-cls.cached_jsonschema = {}
+# Use None instead of {} because {} is truthy and would prevent regeneration
+cls.cached_jsonschema = None

credentials_fields = cls.get_credentials_fields()
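A standalone sketch of the recursive `__subclasses__()` walk that `clear_all_schema_caches` relies on above; the class names here are illustrative placeholders, not the real block schemas.

```python
# Sketch only: shows how a per-class ClassVar cache is cleared across an
# entire subclass tree, including grandchildren.
from typing import Any, ClassVar


class Base:
    cached: ClassVar[dict[str, Any] | None] = None

    @classmethod
    def clear_cache(cls) -> None:
        cls.cached = None


class Child(Base):
    cached = {"stale": True}


class GrandChild(Child):
    cached = {"stale": True}


def clear_all(root: type) -> None:
    # Depth-first walk so every subclass drops its own cached schema.
    if hasattr(root, "clear_cache"):
        root.clear_cache()
    for sub in root.__subclasses__():
        clear_all(sub)


clear_all(Base)
assert Child.cached is None and GrandChild.cached is None
```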
@@ -7,7 +7,6 @@ from backend.blocks._base import (
BlockSchemaOutput,
)
from backend.blocks.llm import (
-DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
@@ -16,6 +15,7 @@ from backend.blocks.llm import (
LlmModel,
LLMResponse,
llm_call,
+llm_model_schema_extra,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField

@@ -50,9 +50,10 @@ class AIConditionBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
-default=DEFAULT_LLM_MODEL,
+default_factory=LlmModel.default,
description="The language model to use for evaluating the condition.",
advanced=False,
+json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()

@@ -82,7 +83,7 @@ class AIConditionBlock(AIBlockBase):
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
-"model": DEFAULT_LLM_MODEL,
+"model": LlmModel.default(),
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -4,16 +4,18 @@ import logging
import re
import secrets
from abc import ABC
-from enum import Enum, EnumMeta
+from dataclasses import dataclass
+from enum import Enum
from json import JSONDecodeError
-from typing import Any, Iterable, List, Literal, NamedTuple, Optional
+from typing import Any, Iterable, List, Literal, Optional

import anthropic
import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
-from pydantic import BaseModel, SecretStr
+from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
+from pydantic_core import CoreSchema, core_schema

from backend.blocks._base import (
Block,
@@ -22,6 +24,8 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
+from backend.data import llm_registry
+from backend.data.llm_registry import ModelMetadata
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -66,114 +70,123 @@ TEST_CREDENTIALS_INPUT = {


def AICredentialsField() -> AICredentials:
+"""
+Returns a CredentialsField for LLM providers.
+The discriminator_mapping will be refreshed when the schema is generated
+if it's empty, ensuring the LLM registry is loaded.
+"""
+# Get the mapping now - it may be empty initially, but will be refreshed
+# when the schema is generated via CredentialsMetaInput._add_json_schema_extra
+mapping = llm_registry.get_llm_discriminator_mapping()

return CredentialsField(
description="API key for the LLM provider.",
discriminator="model",
-discriminator_mapping={
+discriminator_mapping=mapping, # May be empty initially, refreshed later
-model.value: model.metadata.provider for model in LlmModel
-},
)


-class ModelMetadata(NamedTuple):
+def llm_model_schema_extra() -> dict[str, Any]:
-provider: str
+return {"options": llm_registry.get_llm_model_schema_options()}
-context_window: int
-max_output_tokens: int | None
-display_name: str
-provider_name: str
-creator_name: str
-price_tier: Literal[1, 2, 3]


-class LlmModelMeta(EnumMeta):
+class LlmModelMeta(type):
-pass
+"""
+Metaclass for LlmModel that enables attribute-style access to dynamic models.

+This allows code like `LlmModel.GPT4O` to work by converting the attribute
+name to a slug format:
+- GPT4O -> gpt-4o
+- GPT4O_MINI -> gpt-4o-mini
+- CLAUDE_3_5_SONNET -> claude-3-5-sonnet
+"""

+def __getattr__(cls, name: str):
+# Don't intercept private/dunder attributes
+if name.startswith("_"):
+raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")

+# Convert attribute name to slug format:
+# 1. Lowercase: GPT4O -> gpt4o
+# 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
+slug = name.lower().replace("_", "-")

+# Check for exact match in registry first (e.g., "o1" stays "o1")
+registry_slugs = llm_registry.get_dynamic_model_slugs()
+if slug in registry_slugs:
+return cls(slug)

+# If no exact match, try inserting hyphen between letter and digit
+# e.g., gpt4o -> gpt-4o
+transformed_slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
+return cls(transformed_slug)

+def __iter__(cls):
+"""Iterate over all models from the registry.

+Yields LlmModel instances for each model in the dynamic registry.
+Used by __get_pydantic_json_schema__ to build model metadata.
+"""
+for model in llm_registry.iter_dynamic_models():
+yield cls(model.slug)


-class LlmModel(str, Enum, metaclass=LlmModelMeta):
+class LlmModel(str, metaclass=LlmModelMeta):
-# OpenAI models
+"""
-O3_MINI = "o3-mini"
+Dynamic LLM model type that accepts any model slug from the registry.
-O3 = "o3-2025-04-16"
-O1 = "o1"
+This is a string subclass (not an Enum) that allows any model slug value.
-O1_MINI = "o1-mini"
+All models are managed via the LLM Registry in the database.
-# GPT-5 models
-GPT5_2 = "gpt-5.2-2025-12-11"
+Usage:
-GPT5_1 = "gpt-5.1-2025-11-13"
+model = LlmModel("gpt-4o") # Direct construction
-GPT5 = "gpt-5-2025-08-07"
+model = LlmModel.GPT4O # Attribute access (converted to "gpt-4o")
-GPT5_MINI = "gpt-5-mini-2025-08-07"
+model.value # Returns the slug string
-GPT5_NANO = "gpt-5-nano-2025-08-07"
+model.provider # Returns the provider from registry
-GPT5_CHAT = "gpt-5-chat-latest"
+"""
-GPT41 = "gpt-4.1-2025-04-14"
-GPT41_MINI = "gpt-4.1-mini-2025-04-14"
+def __new__(cls, value: str):
-GPT4O_MINI = "gpt-4o-mini"
+if isinstance(value, LlmModel):
-GPT4O = "gpt-4o"
+return value
-GPT4_TURBO = "gpt-4-turbo"
+return str.__new__(cls, value)
-GPT3_5_TURBO = "gpt-3.5-turbo"
-# Anthropic models
+@classmethod
-CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
+def __get_pydantic_core_schema__(
-CLAUDE_4_OPUS = "claude-opus-4-20250514"
+cls, source_type: Any, handler: GetCoreSchemaHandler
-CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
+) -> CoreSchema:
-CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
+"""
-CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
+Tell Pydantic how to validate LlmModel.
-CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
-CLAUDE_4_6_OPUS = "claude-opus-4-6"
+Accepts strings and converts them to LlmModel instances.
-CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
+"""
-# AI/ML API models
+return core_schema.no_info_after_validator_function(
-AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
+cls, # The validator function (LlmModel constructor)
-AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
+core_schema.str_schema(), # Accept string input
-AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
+serialization=core_schema.to_string_ser_schema(), # Serialize as string
-AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
+)
-AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
-# Groq models
+@property
-LLAMA3_3_70B = "llama-3.3-70b-versatile"
+def value(self) -> str:
-LLAMA3_1_8B = "llama-3.1-8b-instant"
+"""Return the model slug (for compatibility with enum-style access)."""
-# Ollama models
+return str(self)
-OLLAMA_LLAMA3_3 = "llama3.3"
-OLLAMA_LLAMA3_2 = "llama3.2"
+@classmethod
-OLLAMA_LLAMA3_8B = "llama3"
+def default(cls) -> "LlmModel":
-OLLAMA_LLAMA3_405B = "llama3.1:405b"
+"""
-OLLAMA_DOLPHIN = "dolphin-mistral:latest"
+Get the default model from the registry.
-# OpenRouter models
-OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
+Returns the recommended model if set, otherwise gpt-4o if available
-OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
+and enabled, otherwise the first enabled model from the registry.
-GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
+Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
-GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
+"""
-GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
+from backend.data.llm_registry import get_default_model_slug
-GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
-GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
+slug = get_default_model_slug()
-GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
+if slug is None:
-MISTRAL_NEMO = "mistralai/mistral-nemo"
+# Registry is empty (e.g., at module import time before DB connection).
-COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
+# Fall back to gpt-4o for backward compatibility.
-COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
+slug = "gpt-4o"
-DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
+return cls(slug)
-DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
-PERPLEXITY_SONAR = "perplexity/sonar"
-PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
-PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
-NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
-NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
-AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
-AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
-AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
-MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
-GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
-META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
-META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
-GROK_4 = "x-ai/grok-4"
-GROK_4_FAST = "x-ai/grok-4-fast"
-GROK_4_1_FAST = "x-ai/grok-4.1-fast"
-GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
-KIMI_K2 = "moonshotai/kimi-k2"
-QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
-QWEN3_CODER = "qwen/qwen3-coder"
-# Llama API models
-LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
-LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
-LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
-LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
-# v0 by Vercel models
-V0_1_5_MD = "v0-1.5-md"
-V0_1_5_LG = "v0-1.5-lg"
-V0_1_0_MD = "v0-1.0-md"

@classmethod
def __get_pydantic_json_schema__(cls, schema, handler):
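To make the new str-based model type above concrete, here is a hedged, self-contained sketch of the same Pydantic v2 mechanism; `ModelSlug` and `NodeInput` are illustrative names, not types from the repo, and the sketch omits the registry lookups the real class performs.

```python
# Sketch of a plain-str model type validated via __get_pydantic_core_schema__,
# mirroring the approach used by LlmModel in the hunk above.
from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema


class ModelSlug(str):
    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        # Accept plain strings, wrap them in ModelSlug, serialize back to str.
        return core_schema.no_info_after_validator_function(
            cls,
            core_schema.str_schema(),
            serialization=core_schema.to_string_ser_schema(),
        )


class NodeInput(BaseModel):
    model: ModelSlug


node = NodeInput(model="gpt-4o")
assert isinstance(node.model, ModelSlug)
assert node.model_dump() == {"model": "gpt-4o"}
```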
@@ -181,7 +194,15 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
llm_model_metadata = {}
for model in cls:
model_name = model.value
-metadata = model.metadata
+# Skip disabled models - only show enabled models in the picker
+if not llm_registry.is_model_enabled(model_name):
+continue
+# Use registry directly with None check to gracefully handle
+# missing metadata during startup/import before registry is populated
+metadata = llm_registry.get_llm_model_metadata(model_name)
+if metadata is None:
+# Skip models without metadata (registry not yet populated)
+continue
llm_model_metadata[model_name] = {
"creator": metadata.creator_name,
"creator_name": metadata.creator_name,
@@ -197,7 +218,12 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):

@property
def metadata(self) -> ModelMetadata:
-return MODEL_METADATA[self]
+metadata = llm_registry.get_llm_model_metadata(self.value)
+if metadata:
+return metadata
+raise ValueError(
+f"Missing metadata for model: {self.value}. Model not found in LLM registry."
+)

@property
def provider(self) -> str:
@@ -212,300 +238,125 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
return self.metadata.max_output_tokens


-MODEL_METADATA = {
+# Default model constant for backward compatibility
-# https://platform.openai.com/docs/models
+# Uses the dynamic registry to get the default model
-LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2),
+DEFAULT_LLM_MODEL = LlmModel.default()
-LlmModel.O3_MINI: ModelMetadata(
-"openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1
-), # o3-mini-2025-01-31
-LlmModel.O1: ModelMetadata(
-"openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
-), # o1-2024-12-17
-LlmModel.O1_MINI: ModelMetadata(
-"openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
-), # o1-mini-2024-09-12
-# GPT-5 models
-LlmModel.GPT5_2: ModelMetadata(
-"openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3
-),
-LlmModel.GPT5_1: ModelMetadata(
-"openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2
-),
-LlmModel.GPT5: ModelMetadata(
-"openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1
-),
-LlmModel.GPT5_MINI: ModelMetadata(
-"openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
-),
-LlmModel.GPT5_NANO: ModelMetadata(
-"openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
-),
-LlmModel.GPT5_CHAT: ModelMetadata(
-"openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
-),
-LlmModel.GPT41: ModelMetadata(
-"openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
-),
-LlmModel.GPT41_MINI: ModelMetadata(
-"openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
-),
-LlmModel.GPT4O_MINI: ModelMetadata(
-"openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1
-), # gpt-4o-mini-2024-07-18
-LlmModel.GPT4O: ModelMetadata(
-"openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
-), # gpt-4o-2024-08-06
-LlmModel.GPT4_TURBO: ModelMetadata(
-"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
-), # gpt-4-turbo-2024-04-09
-LlmModel.GPT3_5_TURBO: ModelMetadata(
-"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
-), # gpt-3.5-turbo-0125
-# https://docs.anthropic.com/en/docs/about-claude/models
-LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
-"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
-), # claude-opus-4-1-20250805
-LlmModel.CLAUDE_4_OPUS: ModelMetadata(
-"anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3
-), # claude-4-opus-20250514
-LlmModel.CLAUDE_4_SONNET: ModelMetadata(
-"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
-), # claude-4-sonnet-20250514
-LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
-"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
-), # claude-opus-4-6
-LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
-"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
-), # claude-opus-4-5-20251101
-LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
-"anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3
-), # claude-sonnet-4-5-20250929
-LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
-"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
-), # claude-haiku-4-5-20251001
-LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
-"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
-), # claude-3-haiku-20240307
-# https://docs.aimlapi.com/api-overview/model-database/text-models
-LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata(
-"aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1
-),
-LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata(
-"aiml_api",
-128000,
-40000,
-"Llama 3.1 Nemotron 70B Instruct",
-"AI/ML",
-"Nvidia",
-1,
-),
-LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
-"aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
-),
-LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
-"aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
-),
-LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
-"aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
-),
-# https://console.groq.com/docs/models
-LlmModel.LLAMA3_3_70B: ModelMetadata(
-"groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1
-),
-LlmModel.LLAMA3_1_8B: ModelMetadata(
-"groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
-),
-# https://ollama.com/library
-LlmModel.OLLAMA_LLAMA3_3: ModelMetadata(
-"ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1
-),
-LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
-"ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1
-),
-LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
-"ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
-),
-LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
-"ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
-),
||||||
LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
|
|
||||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
|
||||||
),
|
|
||||||
# https://openrouter.ai/models
|
|
||||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
|
||||||
"open_router",
|
|
||||||
1050000,
|
|
||||||
8192,
|
|
||||||
"Gemini 2.5 Pro Preview 03.25",
|
|
||||||
"OpenRouter",
|
|
||||||
"Google",
|
|
||||||
2,
|
|
||||||
),
|
|
||||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
|
||||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
|
||||||
),
|
|
||||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
|
||||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
|
||||||
),
|
|
||||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
|
||||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
|
||||||
),
|
|
||||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
|
||||||
"open_router",
|
|
||||||
1048576,
|
|
||||||
65535,
|
|
||||||
"Gemini 2.5 Flash Lite Preview 06.17",
|
|
||||||
"OpenRouter",
|
|
||||||
"Google",
|
|
||||||
1,
|
|
||||||
),
|
|
||||||
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
|
|
||||||
"open_router",
|
|
||||||
1048576,
|
|
||||||
8192,
|
|
||||||
"Gemini 2.0 Flash Lite 001",
|
|
||||||
"OpenRouter",
|
|
||||||
"Google",
|
|
||||||
1,
|
|
||||||
),
|
|
||||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
|
||||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
|
||||||
),
|
|
||||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
|
||||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
|
||||||
),
|
|
||||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
|
||||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
|
||||||
),
|
|
||||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
|
||||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
|
||||||
),
|
|
||||||
LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
|
|
||||||
"open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
|
|
||||||
),
|
|
||||||
LlmModel.PERPLEXITY_SONAR: ModelMetadata(
|
|
||||||
"open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
|
|
||||||
),
|
|
||||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
|
||||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
|
||||||
),
|
|
||||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
|
||||||
"open_router",
|
|
||||||
128000,
|
|
||||||
16000,
|
|
||||||
"Sonar Deep Research",
|
|
||||||
"OpenRouter",
|
|
||||||
"Perplexity",
|
|
||||||
3,
|
|
||||||
),
|
|
||||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
|
|
||||||
"open_router",
|
|
||||||
131000,
|
|
||||||
4096,
|
|
||||||
"Hermes 3 Llama 3.1 405B",
|
|
||||||
"OpenRouter",
|
|
||||||
"Nous Research",
|
|
||||||
1,
|
|
||||||
),
|
|
||||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
|
|
||||||
"open_router",
|
|
||||||
12288,
|
|
||||||
12288,
|
|
||||||
"Hermes 3 Llama 3.1 70B",
|
|
||||||
"OpenRouter",
|
|
||||||
"Nous Research",
|
|
||||||
1,
|
|
||||||
),
|
|
||||||
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
|
|
||||||
"open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
|
|
||||||
),
|
|
||||||
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
|
|
||||||
"open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
|
|
||||||
),
|
|
||||||
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
|
|
||||||
"open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
|
|
||||||
),
|
|
||||||
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
|
|
||||||
"open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
|
|
||||||
),
|
|
||||||
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
|
|
||||||
"open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
|
|
||||||
),
|
|
||||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
|
||||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
|
||||||
),
|
|
||||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
|
||||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
|
||||||
),
|
|
||||||
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
|
|
||||||
"open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
|
|
||||||
),
|
|
||||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
|
||||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
|
||||||
),
|
|
||||||
LlmModel.GROK_4: ModelMetadata(
|
|
||||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
|
||||||
),
|
|
||||||
LlmModel.GROK_4_FAST: ModelMetadata(
|
|
||||||
"open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
|
|
||||||
),
|
|
||||||
LlmModel.GROK_4_1_FAST: ModelMetadata(
|
|
||||||
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
|
|
||||||
),
|
|
||||||
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
|
|
||||||
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
|
|
||||||
),
|
|
||||||
LlmModel.KIMI_K2: ModelMetadata(
|
|
||||||
"open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
|
|
||||||
),
|
|
||||||
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
|
|
||||||
"open_router",
|
|
||||||
262144,
|
|
||||||
262144,
|
|
||||||
"Qwen 3 235B A22B Thinking 2507",
|
|
||||||
"OpenRouter",
|
|
||||||
"Qwen",
|
|
||||||
1,
|
|
||||||
),
|
|
||||||
LlmModel.QWEN3_CODER: ModelMetadata(
|
|
||||||
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
|
|
||||||
),
|
|
||||||
# Llama API models
|
|
||||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
|
|
||||||
"llama_api",
|
|
||||||
128000,
|
|
||||||
4028,
|
|
||||||
"Llama 4 Scout 17B 16E Instruct FP8",
|
|
||||||
"Llama API",
|
|
||||||
"Meta",
|
|
||||||
1,
|
|
||||||
),
|
|
||||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
|
|
||||||
"llama_api",
|
|
||||||
128000,
|
|
||||||
4028,
|
|
||||||
"Llama 4 Maverick 17B 128E Instruct FP8",
|
|
||||||
"Llama API",
|
|
||||||
"Meta",
|
|
||||||
1,
|
|
||||||
),
|
|
||||||
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
|
|
||||||
"llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
|
|
||||||
),
|
|
||||||
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
|
|
||||||
"llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
|
|
||||||
),
|
|
||||||
# v0 by Vercel models
|
|
||||||
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1),
|
|
||||||
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1),
|
|
||||||
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
|
|
||||||
|
|
||||||
for model in LlmModel:
|
class ModelUnavailableError(ValueError):
|
||||||
if model not in MODEL_METADATA:
|
"""Raised when a requested LLM model cannot be resolved for use."""
|
||||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ResolvedModel:
|
||||||
|
"""Result of resolving a model for an LLM call."""
|
||||||
|
|
||||||
|
slug: str # The actual model slug to use (may differ from requested if fallback)
|
||||||
|
provider: str
|
||||||
|
context_window: int
|
||||||
|
max_output_tokens: int
|
||||||
|
used_fallback: bool = False
|
||||||
|
original_slug: str | None = None # Set if fallback was used
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_model_for_call(llm_model: LlmModel) -> ResolvedModel:
|
||||||
|
"""
|
||||||
|
Resolve a model for use in an LLM call.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Checking if the model exists in the registry
|
||||||
|
- Falling back to an enabled model from the same provider if disabled
|
||||||
|
- Refreshing the registry cache if model not found (with DB access)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_model: The requested LlmModel
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResolvedModel with all necessary metadata for the call
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ModelUnavailableError: If model cannot be resolved (not found, disabled with no fallback)
|
||||||
|
"""
|
||||||
|
from backend.data.llm_registry import (
|
||||||
|
get_fallback_model_for_disabled,
|
||||||
|
get_model_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_info = get_model_info(llm_model.value)
|
||||||
|
|
||||||
|
# Case 1: Model found and disabled - try fallback
|
||||||
|
if model_info and not model_info.is_enabled:
|
||||||
|
fallback = get_fallback_model_for_disabled(llm_model.value)
|
||||||
|
if fallback:
|
||||||
|
logger.warning(
|
||||||
|
f"Model '{llm_model.value}' is disabled. Using fallback "
|
||||||
|
f"'{fallback.slug}' from same provider ({fallback.metadata.provider})."
|
||||||
|
)
|
||||||
|
return ResolvedModel(
|
||||||
|
slug=fallback.slug,
|
||||||
|
provider=fallback.metadata.provider,
|
||||||
|
context_window=fallback.metadata.context_window,
|
||||||
|
max_output_tokens=fallback.metadata.max_output_tokens or 2**15,
|
||||||
|
used_fallback=True,
|
||||||
|
original_slug=llm_model.value,
|
||||||
|
)
|
||||||
|
raise ModelUnavailableError(
|
||||||
|
f"Model '{llm_model.value}' is disabled and no fallback from the same "
|
||||||
|
f"provider is available. Enable the model or select a different one."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Case 2: Model found and enabled - use it directly
|
||||||
|
if model_info:
|
||||||
|
return ResolvedModel(
|
||||||
|
slug=llm_model.value,
|
||||||
|
provider=model_info.metadata.provider,
|
||||||
|
context_window=model_info.metadata.context_window,
|
||||||
|
max_output_tokens=model_info.metadata.max_output_tokens or 2**15,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Case 3: Model not in registry - try refresh if DB available
|
||||||
|
logger.warning(f"Model '{llm_model.value}' not found in registry cache")
|
||||||
|
|
||||||
|
from backend.data.db import is_connected
|
||||||
|
|
||||||
|
if not is_connected():
|
||||||
|
raise ModelUnavailableError(
|
||||||
|
f"Model '{llm_model.value}' not found in registry. "
|
||||||
|
f"The registry may need to be refreshed via the admin UI."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try refreshing the registry
|
||||||
|
try:
|
||||||
|
logger.info(f"Refreshing LLM registry for model '{llm_model.value}'")
|
||||||
|
await llm_registry.refresh_llm_registry()
|
||||||
|
except Exception as e:
|
||||||
|
raise ModelUnavailableError(
|
||||||
|
f"Model '{llm_model.value}' not found and registry refresh failed: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Check again after refresh
|
||||||
|
model_info = get_model_info(llm_model.value)
|
||||||
|
if not model_info:
|
||||||
|
raise ModelUnavailableError(
|
||||||
|
f"Model '{llm_model.value}' not found in registry. "
|
||||||
|
f"Add it via the admin UI at /admin/llms."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model_info.is_enabled:
|
||||||
|
raise ModelUnavailableError(
|
||||||
|
f"Model '{llm_model.value}' exists but is disabled. "
|
||||||
|
f"Enable it via the admin UI at /admin/llms."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Model '{llm_model.value}' loaded after registry refresh")
|
||||||
|
return ResolvedModel(
|
||||||
|
slug=llm_model.value,
|
||||||
|
provider=model_info.metadata.provider,
|
||||||
|
context_window=model_info.metadata.context_window,
|
||||||
|
max_output_tokens=model_info.metadata.max_output_tokens or 2**15,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolCall(BaseModel):
|
class ToolCall(BaseModel):
|
||||||
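For orientation, a minimal usage sketch of the resolver added above (illustrative only, not part of the diff; assumes an async caller, that the slug exists in the registry, and that the dynamic LlmModel enum accepts the raw slug):

    # Hypothetical call site
    resolved = await resolve_model_for_call(LlmModel("gpt-4o"))
    if resolved.used_fallback:
        logger.warning(f"Fell back from {resolved.original_slug} to {resolved.slug}")
    max_out = resolved.max_output_tokens  # already defaulted to 2**15 when unset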
@@ -531,12 +382,12 @@ class LLMResponse(BaseModel):

 def convert_openai_tool_fmt_to_anthropic(
     openai_tools: list[dict] | None = None,
-) -> Iterable[ToolParam] | anthropic.Omit:
+) -> Iterable[ToolParam] | anthropic.NotGiven:
     """
     Convert OpenAI tool format to Anthropic tool format.
     """
     if not openai_tools or len(openai_tools) == 0:
-        return anthropic.omit
+        return anthropic.NOT_GIVEN

     anthropic_tools = []
     for tool in openai_tools:
@@ -598,7 +449,12 @@ def get_parallel_tool_calls_param(
     llm_model: LlmModel, parallel_tool_calls: bool | None
 ) -> bool | openai.Omit:
     """Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
-    if llm_model.startswith("o") or parallel_tool_calls is None:
+    # Check for o-series models (o1, o1-mini, o3-mini, etc.) which don't support
+    # parallel tool calls. Handle both bare slugs ("o1-mini") and provider-prefixed
+    # slugs ("openai/o1-mini"). The pattern matches "o" followed by a digit at the
+    # start of the string or after a "/" separator.
+    is_o_series = re.search(r"(^|/)o\d", llm_model) is not None
+    if is_o_series or parallel_tool_calls is None:
         return openai.omit
     return parallel_tool_calls

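A quick standalone illustration of the o-series check introduced above (the slugs are made-up examples, not taken from the diff):

    import re
    # "o" followed by a digit, at the start of the slug or right after a provider prefix
    assert re.search(r"(^|/)o\d", "o1-mini") is not None
    assert re.search(r"(^|/)o\d", "openai/o3-mini") is not None
    assert re.search(r"(^|/)o\d", "gpt-4o") is None  # trailing "o" does not match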
@@ -634,15 +490,22 @@ async def llm_call(
        - prompt_tokens: The number of tokens used in the prompt.
        - completion_tokens: The number of tokens used in the completion.
    """
-    provider = llm_model.metadata.provider
-    context_window = llm_model.context_window
+    # Resolve the model - handles disabled models, fallbacks, and cache misses
+    resolved = await resolve_model_for_call(llm_model)
+
+    model_to_use = resolved.slug
+    provider = resolved.provider
+    context_window = resolved.context_window
+    model_max_output = resolved.max_output_tokens
+
+    # Create effective model for model-specific parameter resolution (e.g., o-series check)
+    effective_model = LlmModel(model_to_use)
+
    if compress_prompt_to_fit:
        result = await compress_context(
            messages=prompt,
-            target_tokens=llm_model.context_window // 2,
+            target_tokens=context_window // 2,
            client=None,  # Truncation-only, no LLM summarization
-            reserve=0,  # Caller handles response token budget separately
        )
        if result.error:
            logger.warning(
@@ -653,7 +516,7 @@ async def llm_call(

    # Calculate available tokens based on context window and input length
    estimated_input_tokens = estimate_token_count(prompt)
-    model_max_output = llm_model.max_output_tokens or int(2**15)
+    # model_max_output already set above
    user_max = max_tokens or model_max_output
    available_tokens = max(context_window - estimated_input_tokens, 0)
    max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
@@ -664,14 +527,14 @@ async def llm_call(
        response_format = None

        parallel_tool_calls = get_parallel_tool_calls_param(
-            llm_model, parallel_tool_calls
+            effective_model, parallel_tool_calls
        )

        if force_json_output:
            response_format = {"type": "json_object"}

        response = await oai_client.chat.completions.create(
-            model=llm_model.value,
+            model=model_to_use,
            messages=prompt,  # type: ignore
            response_format=response_format,  # type: ignore
            max_completion_tokens=max_tokens,
@@ -718,7 +581,7 @@ async def llm_call(
        )
        try:
            resp = await client.messages.create(
-                model=llm_model.value,
+                model=model_to_use,
                system=sysprompt,
                messages=messages,
                max_tokens=max_tokens,
@@ -782,7 +645,7 @@ async def llm_call(
        client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
        response_format = {"type": "json_object"} if force_json_output else None
        response = await client.chat.completions.create(
-            model=llm_model.value,
+            model=model_to_use,
            messages=prompt,  # type: ignore
            response_format=response_format,  # type: ignore
            max_tokens=max_tokens,
@@ -804,7 +667,7 @@ async def llm_call(
        sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
        usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
        response = await client.generate(
-            model=llm_model.value,
+            model=model_to_use,
            prompt=f"{sys_messages}\n\n{usr_messages}",
            stream=False,
            options={"num_ctx": max_tokens},
@@ -826,7 +689,7 @@ async def llm_call(
        )

        parallel_tool_calls_param = get_parallel_tool_calls_param(
-            llm_model, parallel_tool_calls
+            effective_model, parallel_tool_calls
        )

        response = await client.chat.completions.create(
@@ -834,7 +697,7 @@ async def llm_call(
                "HTTP-Referer": "https://agpt.co",
                "X-Title": "AutoGPT",
            },
-            model=llm_model.value,
+            model=model_to_use,
            messages=prompt,  # type: ignore
            max_tokens=max_tokens,
            tools=tools_param,  # type: ignore
@@ -868,7 +731,7 @@ async def llm_call(
        )

        parallel_tool_calls_param = get_parallel_tool_calls_param(
-            llm_model, parallel_tool_calls
+            effective_model, parallel_tool_calls
        )

        response = await client.chat.completions.create(
@@ -876,7 +739,7 @@ async def llm_call(
                "HTTP-Referer": "https://agpt.co",
                "X-Title": "AutoGPT",
            },
-            model=llm_model.value,
+            model=model_to_use,
            messages=prompt,  # type: ignore
            max_tokens=max_tokens,
            tools=tools_param,  # type: ignore
@@ -903,7 +766,7 @@ async def llm_call(
            reasoning=reasoning,
        )
    elif provider == "aiml_api":
-        client = openai.OpenAI(
+        client = openai.AsyncOpenAI(
            base_url="https://api.aimlapi.com/v2",
            api_key=credentials.api_key.get_secret_value(),
            default_headers={
@@ -913,8 +776,8 @@ async def llm_call(
            },
        )

-        completion = client.chat.completions.create(
-            model=llm_model.value,
+        completion = await client.chat.completions.create(
+            model=model_to_use,
            messages=prompt,  # type: ignore
            max_tokens=max_tokens,
        )
@@ -942,11 +805,11 @@ async def llm_call(
            response_format = {"type": "json_object"}

        parallel_tool_calls_param = get_parallel_tool_calls_param(
-            llm_model, parallel_tool_calls
+            effective_model, parallel_tool_calls
        )

        response = await client.chat.completions.create(
-            model=llm_model.value,
+            model=model_to_use,
            messages=prompt,  # type: ignore
            response_format=response_format,  # type: ignore
            max_tokens=max_tokens,
@@ -997,9 +860,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
-            default=DEFAULT_LLM_MODEL,
+            default_factory=LlmModel.default,
            description="The language model to use for answering the prompt.",
            advanced=False,
+            json_schema_extra=llm_model_schema_extra(),
        )
        force_json_output: bool = SchemaField(
            title="Restrict LLM to pure JSON output",
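Side note on the default_factory switch above: a callable default is evaluated each time the schema is instantiated, so the default model can track whatever the registry currently reports instead of being frozen at import time. A generic pydantic sketch of the same idea (the SchemaField-specific behavior is assumed, not shown in this diff):

    from pydantic import BaseModel, Field

    def current_default() -> str:
        return "gpt-4o"  # stand-in for LlmModel.default()

    class Example(BaseModel):
        model: str = Field(default_factory=current_default)  # re-evaluated per instance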
@@ -1062,7 +926,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
            input_schema=AIStructuredResponseGeneratorBlock.Input,
            output_schema=AIStructuredResponseGeneratorBlock.Output,
            test_input={
-                "model": DEFAULT_LLM_MODEL,
+                "model": "gpt-4o",  # Using string value - enum accepts any model slug dynamically
                "credentials": TEST_CREDENTIALS_INPUT,
                "expected_format": {
                    "key1": "value1",
@@ -1428,9 +1292,10 @@ class AITextGeneratorBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
-            default=DEFAULT_LLM_MODEL,
+            default_factory=LlmModel.default,
            description="The language model to use for answering the prompt.",
            advanced=False,
+            json_schema_extra=llm_model_schema_extra(),
        )
        credentials: AICredentials = AICredentialsField()
        sys_prompt: str = SchemaField(
@@ -1524,8 +1389,9 @@ class AITextSummarizerBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
-            default=DEFAULT_LLM_MODEL,
+            default_factory=LlmModel.default,
            description="The language model to use for summarizing the text.",
+            json_schema_extra=llm_model_schema_extra(),
        )
        focus: str = SchemaField(
            title="Focus",
@@ -1741,8 +1607,9 @@ class AIConversationBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
-            default=DEFAULT_LLM_MODEL,
+            default_factory=LlmModel.default,
            description="The language model to use for the conversation.",
+            json_schema_extra=llm_model_schema_extra(),
        )
        credentials: AICredentials = AICredentialsField()
        max_tokens: int | None = SchemaField(
@@ -1779,7 +1646,7 @@ class AIConversationBlock(AIBlockBase):
                },
                {"role": "user", "content": "Where was it played?"},
            ],
-            "model": DEFAULT_LLM_MODEL,
+            "model": "gpt-4o",  # Using string value - enum accepts any model slug dynamically
            "credentials": TEST_CREDENTIALS_INPUT,
        },
        test_credentials=TEST_CREDENTIALS,
@@ -1842,9 +1709,10 @@ class AIListGeneratorBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
-            default=DEFAULT_LLM_MODEL,
+            default_factory=LlmModel.default,
            description="The language model to use for generating the list.",
            advanced=True,
+            json_schema_extra=llm_model_schema_extra(),
        )
        credentials: AICredentials = AICredentialsField()
        max_retries: int = SchemaField(
@@ -1899,7 +1767,7 @@ class AIListGeneratorBlock(AIBlockBase):
                "drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
                "fictional worlds."
            ),
-            "model": DEFAULT_LLM_MODEL,
+            "model": "gpt-4o",  # Using string value - enum accepts any model slug dynamically
            "credentials": TEST_CREDENTIALS_INPUT,
            "max_retries": 3,
            "force_json_output": False,
@@ -226,9 +226,10 @@ class SmartDecisionMakerBlock(Block):
        )
        model: llm.LlmModel = SchemaField(
            title="LLM Model",
-            default=llm.DEFAULT_LLM_MODEL,
+            default_factory=llm.LlmModel.default,
            description="The language model to use for answering the prompt.",
            advanced=False,
+            json_schema_extra=llm.llm_model_schema_extra(),
        )
        credentials: llm.AICredentials = llm.AICredentialsField()
        multiple_tool_calls: bool = SchemaField(
@@ -10,13 +10,13 @@ import stagehand.main
 from stagehand import Stagehand

 from backend.blocks.llm import (
-    MODEL_METADATA,
     AICredentials,
     AICredentialsField,
     LlmModel,
     ModelMetadata,
 )
 from backend.blocks.stagehand._config import stagehand as stagehand_provider
+from backend.data import llm_registry
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
        Returns the provider name for the model in the required format for Stagehand:
            provider/model_name
        """
-        model_metadata = MODEL_METADATA[LlmModel(self.value)]
+        model_metadata = self.metadata
        model_name = self.value

        if len(model_name.split("/")) == 1 and not self.value.startswith(
@@ -102,24 +102,28 @@ class StagehandRecommendedLlmModel(str, Enum):
            ), "Logic failed and open_router provider attempted to be prepended to model name! in stagehand/_config.py"
            model_name = f"{model_metadata.provider}/{model_name}"

-        logger.error(f"Model name: {model_name}")
+        logger.debug(f"Model name: {model_name}")
        return model_name

    @property
    def provider(self) -> str:
-        return MODEL_METADATA[LlmModel(self.value)].provider
+        return self.metadata.provider

    @property
    def metadata(self) -> ModelMetadata:
-        return MODEL_METADATA[LlmModel(self.value)]
+        metadata = llm_registry.get_llm_model_metadata(self.value)
+        if metadata:
+            return metadata
+        # Fallback to LlmModel enum if registry lookup fails
+        return LlmModel(self.value).metadata

    @property
    def context_window(self) -> int:
-        return MODEL_METADATA[LlmModel(self.value)].context_window
+        return self.metadata.context_window

    @property
    def max_output_tokens(self) -> int | None:
-        return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
+        return self.metadata.max_output_tokens


 class StagehandObserveBlock(Block):
@@ -1,7 +1,6 @@
 import logging
 import os

-import pytest
 import pytest_asyncio
 from dotenv import load_dotenv

@@ -28,54 +27,6 @@ async def server():
    yield server


-@pytest.fixture
-def test_user_id() -> str:
-    """Test user ID fixture."""
-    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
-
-
-@pytest.fixture
-def admin_user_id() -> str:
-    """Admin user ID fixture."""
-    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
-
-
-@pytest.fixture
-def target_user_id() -> str:
-    """Target user ID fixture."""
-    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
-
-
-@pytest.fixture
-async def setup_test_user(test_user_id):
-    """Create test user in database before tests."""
-    from backend.data.user import get_or_create_user
-
-    # Create the test user in the database using JWT token format
-    user_data = {
-        "sub": test_user_id,
-        "email": "test@example.com",
-        "user_metadata": {"name": "Test User"},
-    }
-    await get_or_create_user(user_data)
-    return test_user_id
-
-
-@pytest.fixture
-async def setup_admin_user(admin_user_id):
-    """Create admin user in database before tests."""
-    from backend.data.user import get_or_create_user
-
-    # Create the admin user in the database using JWT token format
-    user_data = {
-        "sub": admin_user_id,
-        "email": "test-admin@example.com",
-        "user_metadata": {"name": "Test Admin"},
-    }
-    await get_or_create_user(user_data)
-    return admin_user_id
-
-
 @pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
 async def graph_cleanup(server):
    created_graph_ids = []
@@ -1 +0,0 @@
-
@@ -1,18 +0,0 @@
-"""Entry point for running the CoPilot Executor service.
-
-Usage:
-    python -m backend.copilot.executor
-"""
-
-from backend.app import run_processes
-
-from .manager import CoPilotExecutor
-
-
-def main():
-    """Run the CoPilot Executor service."""
-    run_processes(CoPilotExecutor())
-
-
-if __name__ == "__main__":
-    main()
@@ -1,521 +0,0 @@
|
|||||||
"""CoPilot Executor Manager - main service for CoPilot task execution.
|
|
||||||
|
|
||||||
This module contains the CoPilotExecutor class that consumes chat tasks from
|
|
||||||
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
|
||||||
|
|
||||||
from pika.adapters.blocking_connection import BlockingChannel
|
|
||||||
from pika.exceptions import AMQPChannelError, AMQPConnectionError
|
|
||||||
from pika.spec import Basic, BasicProperties
|
|
||||||
from prometheus_client import Gauge, start_http_server
|
|
||||||
|
|
||||||
from backend.data import redis_client as redis
|
|
||||||
from backend.data.rabbitmq import SyncRabbitMQ
|
|
||||||
from backend.executor.cluster_lock import ClusterLock
|
|
||||||
from backend.util.decorator import error_logged
|
|
||||||
from backend.util.logging import TruncatedLogger
|
|
||||||
from backend.util.process import AppProcess
|
|
||||||
from backend.util.retry import continuous_retry
|
|
||||||
from backend.util.settings import Settings
|
|
||||||
|
|
||||||
from .processor import execute_copilot_task, init_worker
|
|
||||||
from .utils import (
|
|
||||||
COPILOT_CANCEL_QUEUE_NAME,
|
|
||||||
COPILOT_EXECUTION_QUEUE_NAME,
|
|
||||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
|
||||||
CancelCoPilotEvent,
|
|
||||||
CoPilotExecutionEntry,
|
|
||||||
create_copilot_queue_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
|
||||||
settings = Settings()
|
|
||||||
|
|
||||||
# Prometheus metrics
|
|
||||||
active_tasks_gauge = Gauge(
|
|
||||||
"copilot_executor_active_tasks",
|
|
||||||
"Number of active CoPilot tasks",
|
|
||||||
)
|
|
||||||
pool_size_gauge = Gauge(
|
|
||||||
"copilot_executor_pool_size",
|
|
||||||
"Maximum number of CoPilot executor workers",
|
|
||||||
)
|
|
||||||
utilization_gauge = Gauge(
|
|
||||||
"copilot_executor_utilization_ratio",
|
|
||||||
"Ratio of active tasks to pool size",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CoPilotExecutor(AppProcess):
|
|
||||||
"""CoPilot Executor service for processing chat generation tasks.
|
|
||||||
|
|
||||||
This service consumes tasks from RabbitMQ, processes them using a thread pool,
|
|
||||||
and publishes results to Redis Streams. It follows the graph executor pattern
|
|
||||||
for reliable message handling and graceful shutdown.
|
|
||||||
|
|
||||||
Key features:
|
|
||||||
- RabbitMQ-based task distribution with manual acknowledgment
|
|
||||||
- Thread pool executor for concurrent task processing
|
|
||||||
- Cluster lock for duplicate prevention across pods
|
|
||||||
- Graceful shutdown with timeout for in-flight tasks
|
|
||||||
- FANOUT exchange for cancellation broadcast
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.pool_size = settings.config.num_copilot_workers
|
|
||||||
self.active_tasks: dict[str, tuple[Future, threading.Event]] = {}
|
|
||||||
self.executor_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
self._executor = None
|
|
||||||
self._stop_consuming = None
|
|
||||||
|
|
||||||
self._cancel_thread = None
|
|
||||||
self._cancel_client = None
|
|
||||||
self._run_thread = None
|
|
||||||
self._run_client = None
|
|
||||||
|
|
||||||
self._task_locks: dict[str, ClusterLock] = {}
|
|
||||||
self._active_tasks_lock = threading.Lock()
|
|
||||||
|
|
||||||
# ============ Main Entry Points (AppProcess interface) ============ #
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
"""Main service loop - consume from RabbitMQ."""
|
|
||||||
logger.info(f"Pod assigned executor_id: {self.executor_id}")
|
|
||||||
logger.info(f"Spawn max-{self.pool_size} workers...")
|
|
||||||
|
|
||||||
pool_size_gauge.set(self.pool_size)
|
|
||||||
self._update_metrics()
|
|
||||||
start_http_server(settings.config.copilot_executor_port)
|
|
||||||
|
|
||||||
self.cancel_thread.start()
|
|
||||||
self.run_thread.start()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
time.sleep(1e5)
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Graceful shutdown with active execution waiting."""
|
|
||||||
pid = os.getpid()
|
|
||||||
logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
|
|
||||||
|
|
||||||
# Signal the consumer thread to stop
|
|
||||||
try:
|
|
||||||
self.stop_consuming.set()
|
|
||||||
run_channel = self.run_client.get_channel()
|
|
||||||
run_channel.connection.add_callback_threadsafe(
|
|
||||||
lambda: run_channel.stop_consuming()
|
|
||||||
)
|
|
||||||
logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
|
|
||||||
|
|
||||||
# Wait for active executions to complete
|
|
||||||
if self.active_tasks:
|
|
||||||
logger.info(
|
|
||||||
f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
|
|
||||||
)
|
|
||||||
|
|
||||||
start_time = time.monotonic()
|
|
||||||
last_refresh = start_time
|
|
||||||
lock_refresh_interval = settings.config.cluster_lock_timeout / 10
|
|
||||||
|
|
||||||
while (
|
|
||||||
self.active_tasks
|
|
||||||
and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
|
|
||||||
):
|
|
||||||
self._cleanup_completed_tasks()
|
|
||||||
if not self.active_tasks:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Refresh cluster locks periodically
|
|
||||||
current_time = time.monotonic()
|
|
||||||
if current_time - last_refresh >= lock_refresh_interval:
|
|
||||||
for lock in list(self._task_locks.values()):
|
|
||||||
try:
|
|
||||||
lock.refresh()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"[cleanup {pid}] Failed to refresh lock: {e}"
|
|
||||||
)
|
|
||||||
last_refresh = current_time
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
|
|
||||||
)
|
|
||||||
time.sleep(10.0)
|
|
||||||
|
|
||||||
# Stop message consumers
|
|
||||||
if self._run_thread:
|
|
||||||
self._stop_message_consumers(
|
|
||||||
self._run_thread, self.run_client, "[cleanup][run]"
|
|
||||||
)
|
|
||||||
if self._cancel_thread:
|
|
||||||
self._stop_message_consumers(
|
|
||||||
self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clean up worker threads (closes per-loop workspace storage sessions)
|
|
||||||
if self._executor:
|
|
||||||
from .processor import cleanup_worker
|
|
||||||
|
|
||||||
logger.info(f"[cleanup {pid}] Cleaning up workers...")
|
|
||||||
futures = []
|
|
||||||
for _ in range(self._executor._max_workers):
|
|
||||||
futures.append(self._executor.submit(cleanup_worker))
|
|
||||||
for f in futures:
|
|
||||||
try:
|
|
||||||
f.result(timeout=10)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[cleanup {pid}] Worker cleanup error: {e}")
|
|
||||||
|
|
||||||
logger.info(f"[cleanup {pid}] Shutting down executor...")
|
|
||||||
self._executor.shutdown(wait=False)
|
|
||||||
|
|
||||||
# Release any remaining locks
|
|
||||||
for task_id, lock in list(self._task_locks.items()):
|
|
||||||
try:
|
|
||||||
lock.release()
|
|
||||||
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
|
|
||||||
|
|
||||||
# ============ RabbitMQ Consumer Methods ============ #
|
|
||||||
|
|
||||||
@continuous_retry()
|
|
||||||
def _consume_cancel(self):
|
|
||||||
"""Consume cancellation messages from FANOUT exchange."""
|
|
||||||
if self.stop_consuming.is_set() and not self.active_tasks:
|
|
||||||
logger.info("Stop reconnecting cancel consumer - service cleaned up")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.cancel_client.is_ready:
|
|
||||||
self.cancel_client.disconnect()
|
|
||||||
self.cancel_client.connect()
|
|
||||||
|
|
||||||
# Check again after connect - shutdown may have been requested
|
|
||||||
if self.stop_consuming.is_set() and not self.active_tasks:
|
|
||||||
logger.info("Stop consuming requested during reconnect - disconnecting")
|
|
||||||
self.cancel_client.disconnect()
|
|
||||||
return
|
|
||||||
|
|
||||||
cancel_channel = self.cancel_client.get_channel()
|
|
||||||
cancel_channel.basic_consume(
|
|
||||||
queue=COPILOT_CANCEL_QUEUE_NAME,
|
|
||||||
on_message_callback=self._handle_cancel_message,
|
|
||||||
auto_ack=True,
|
|
||||||
)
|
|
||||||
logger.info("Starting to consume cancel messages...")
|
|
||||||
cancel_channel.start_consuming()
|
|
||||||
if not self.stop_consuming.is_set() or self.active_tasks:
|
|
||||||
raise RuntimeError("Cancel message consumer stopped unexpectedly")
|
|
||||||
logger.info("Cancel message consumer stopped gracefully")
|
|
||||||
|
|
||||||
@continuous_retry()
|
|
||||||
def _consume_run(self):
|
|
||||||
"""Consume run messages from DIRECT exchange."""
|
|
||||||
if self.stop_consuming.is_set():
|
|
||||||
logger.info("Stop reconnecting run consumer - service cleaned up")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.run_client.is_ready:
|
|
||||||
self.run_client.disconnect()
|
|
||||||
self.run_client.connect()
|
|
||||||
|
|
||||||
# Check again after connect - shutdown may have been requested
|
|
||||||
if self.stop_consuming.is_set():
|
|
||||||
logger.info("Stop consuming requested during reconnect - disconnecting")
|
|
||||||
self.run_client.disconnect()
|
|
||||||
return
|
|
||||||
|
|
||||||
run_channel = self.run_client.get_channel()
|
|
||||||
run_channel.basic_qos(prefetch_count=self.pool_size)
|
|
||||||
|
|
||||||
run_channel.basic_consume(
|
|
||||||
queue=COPILOT_EXECUTION_QUEUE_NAME,
|
|
||||||
on_message_callback=self._handle_run_message,
|
|
||||||
auto_ack=False,
|
|
||||||
consumer_tag="copilot_execution_consumer",
|
|
||||||
)
|
|
||||||
logger.info("Starting to consume run messages...")
|
|
||||||
run_channel.start_consuming()
|
|
||||||
if not self.stop_consuming.is_set():
|
|
||||||
raise RuntimeError("Run message consumer stopped unexpectedly")
|
|
||||||
logger.info("Run message consumer stopped gracefully")
|
|
||||||
|
|
||||||
# ============ Message Handlers ============ #
|
|
||||||
|
|
||||||
@error_logged(swallow=True)
|
|
||||||
def _handle_cancel_message(
|
|
||||||
self,
|
|
||||||
_channel: BlockingChannel,
|
|
||||||
_method: Basic.Deliver,
|
|
||||||
_properties: BasicProperties,
|
|
||||||
body: bytes,
|
|
||||||
):
|
|
||||||
"""Handle cancel message from FANOUT exchange."""
|
|
||||||
request = CancelCoPilotEvent.model_validate_json(body)
|
|
||||||
task_id = request.task_id
|
|
||||||
if not task_id:
|
|
||||||
logger.warning("Cancel message missing 'task_id'")
|
|
||||||
return
|
|
||||||
if task_id not in self.active_tasks:
|
|
||||||
logger.debug(f"Cancel received for {task_id} but not active")
|
|
||||||
return
|
|
||||||
|
|
||||||
_, cancel_event = self.active_tasks[task_id]
|
|
||||||
logger.info(f"Received cancel for {task_id}")
|
|
||||||
if not cancel_event.is_set():
|
|
||||||
cancel_event.set()
|
|
||||||
else:
|
|
||||||
logger.debug(f"Cancel already set for {task_id}")
|
|
||||||
|
|
||||||
def _handle_run_message(
|
|
||||||
self,
|
|
||||||
_channel: BlockingChannel,
|
|
||||||
method: Basic.Deliver,
|
|
||||||
_properties: BasicProperties,
|
|
||||||
body: bytes,
|
|
||||||
):
|
|
||||||
"""Handle run message from DIRECT exchange."""
|
|
||||||
delivery_tag = method.delivery_tag
|
|
||||||
# Capture the channel used at message delivery time to ensure we ack
|
|
||||||
# on the correct channel. Delivery tags are channel-scoped and become
|
|
||||||
# invalid if the channel is recreated after reconnection.
|
|
||||||
delivery_channel = _channel
|
|
||||||
|
|
||||||
def ack_message(reject: bool, requeue: bool):
|
|
||||||
"""Acknowledge or reject the message.
|
|
||||||
|
|
||||||
Uses the channel from the original message delivery. If the channel
|
|
||||||
is no longer open (e.g., after reconnection), logs a warning and
|
|
||||||
skips the ack - RabbitMQ will redeliver the message automatically.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not delivery_channel.is_open:
|
|
||||||
logger.warning(
|
|
||||||
f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
|
|
||||||
"Message will be redelivered by RabbitMQ."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if reject:
|
|
||||||
delivery_channel.connection.add_callback_threadsafe(
|
|
||||||
lambda: delivery_channel.basic_nack(
|
|
||||||
delivery_tag, requeue=requeue
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
delivery_channel.connection.add_callback_threadsafe(
|
|
||||||
lambda: delivery_channel.basic_ack(delivery_tag)
|
|
||||||
)
|
|
||||||
except (AMQPChannelError, AMQPConnectionError) as e:
|
|
||||||
# Channel/connection errors indicate stale delivery tag - don't retry
|
|
||||||
logger.warning(
|
|
||||||
f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
|
|
||||||
f"error: {e}. Message will be redelivered by RabbitMQ."
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
# Other errors might be transient, but log and skip to avoid blocking
|
|
||||||
logger.error(
|
|
||||||
f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if we're shutting down
|
|
||||||
if self.stop_consuming.is_set():
|
|
||||||
logger.info("Rejecting new task during shutdown")
|
|
||||||
ack_message(reject=True, requeue=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if we can accept more tasks
|
|
||||||
self._cleanup_completed_tasks()
|
|
||||||
if len(self.active_tasks) >= self.pool_size:
|
|
||||||
ack_message(reject=True, requeue=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
entry = CoPilotExecutionEntry.model_validate_json(body)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Could not parse run message: {e}, body={body}")
|
|
||||||
ack_message(reject=True, requeue=False)
|
|
||||||
return
|
|
||||||
|
|
||||||
task_id = entry.task_id
|
|
||||||
|
|
||||||
# Check for local duplicate - task is already running on this executor
|
|
||||||
if task_id in self.active_tasks:
|
|
||||||
logger.warning(
|
|
||||||
f"Task {task_id} already running locally, rejecting duplicate"
|
|
||||||
)
|
|
||||||
ack_message(reject=True, requeue=False)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Try to acquire cluster-wide lock
|
|
||||||
cluster_lock = ClusterLock(
|
|
||||||
redis=redis.get_redis(),
|
|
||||||
key=f"copilot:task:{task_id}:lock",
|
|
||||||
owner_id=self.executor_id,
|
|
||||||
timeout=settings.config.cluster_lock_timeout,
|
|
||||||
)
|
|
||||||
current_owner = cluster_lock.try_acquire()
|
|
||||||
if current_owner != self.executor_id:
|
|
||||||
if current_owner is not None:
|
|
||||||
logger.warning(f"Task {task_id} already running on pod {current_owner}")
|
|
||||||
ack_message(reject=True, requeue=False)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Could not acquire lock for {task_id} - Redis unavailable"
|
|
||||||
)
|
|
||||||
ack_message(reject=True, requeue=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Execute the task
|
|
||||||
try:
|
|
||||||
self._task_locks[task_id] = cluster_lock
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
cancel_event = threading.Event()
|
|
||||||
future = self.executor.submit(
|
|
||||||
execute_copilot_task, entry, cancel_event, cluster_lock
|
|
||||||
)
|
|
||||||
self.active_tasks[task_id] = (future, cancel_event)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to setup execution for {task_id}: {e}")
|
|
||||||
cluster_lock.release()
|
|
||||||
if task_id in self._task_locks:
|
|
||||||
del self._task_locks[task_id]
|
|
||||||
ack_message(reject=True, requeue=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
self._update_metrics()
|
|
||||||
|
|
||||||
def on_run_done(f: Future):
|
|
||||||
logger.info(f"Run completed for {task_id}")
|
|
||||||
try:
|
|
||||||
if exec_error := f.exception():
|
|
||||||
logger.error(f"Execution for {task_id} failed: {exec_error}")
|
|
||||||
# Don't requeue failed tasks - they've been marked as failed
|
|
||||||
# in the stream registry. Requeuing would cause infinite retries
|
|
||||||
# for deterministic failures.
|
|
||||||
ack_message(reject=True, requeue=False)
|
|
||||||
else:
|
|
||||||
ack_message(reject=False, requeue=False)
|
|
||||||
except BaseException as e:
|
|
||||||
logger.exception(f"Error in run completion callback: {e}")
|
|
||||||
finally:
|
|
||||||
# Release the cluster lock
|
|
||||||
if task_id in self._task_locks:
|
|
||||||
logger.info(f"Releasing cluster lock for {task_id}")
|
|
||||||
self._task_locks[task_id].release()
|
|
||||||
del self._task_locks[task_id]
|
|
||||||
self._cleanup_completed_tasks()
|
|
||||||
|
|
||||||
future.add_done_callback(on_run_done)
|
|
||||||
|
|
||||||
    # ============ Helper Methods ============ #

    def _cleanup_completed_tasks(self) -> list[str]:
        """Remove completed futures from active_tasks and update metrics."""
        completed_tasks = []
        with self._active_tasks_lock:
            for task_id, (future, _) in list(self.active_tasks.items()):
                if future.done():
                    completed_tasks.append(task_id)
                    self.active_tasks.pop(task_id, None)
                    logger.info(f"Cleaned up completed task {task_id}")

        self._update_metrics()
        return completed_tasks

    def _update_metrics(self):
        """Update Prometheus metrics."""
        active_count = len(self.active_tasks)
        active_tasks_gauge.set(active_count)
        if self.stop_consuming.is_set():
            utilization_gauge.set(1.0)
        else:
            utilization_gauge.set(
                active_count / self.pool_size if self.pool_size > 0 else 0
            )

    def _stop_message_consumers(
        self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
    ):
        """Stop a message consumer thread."""
        try:
            channel = client.get_channel()
            channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())

            thread.join(timeout=300)
            if thread.is_alive():
                logger.error(
                    f"{prefix} Thread did not finish in time, forcing disconnect"
                )

            client.disconnect()
            logger.info(f"{prefix} Client disconnected")
        except Exception as e:
            logger.error(f"{prefix} Error disconnecting client: {e}")

    # ============ Lazy-initialized Properties ============ #

    @property
    def cancel_thread(self) -> threading.Thread:
        if self._cancel_thread is None:
            self._cancel_thread = threading.Thread(
                target=lambda: self._consume_cancel(),
                daemon=True,
            )
        return self._cancel_thread

    @property
    def run_thread(self) -> threading.Thread:
        if self._run_thread is None:
            self._run_thread = threading.Thread(
                target=lambda: self._consume_run(),
                daemon=True,
            )
        return self._run_thread

    @property
    def stop_consuming(self) -> threading.Event:
        if self._stop_consuming is None:
            self._stop_consuming = threading.Event()
        return self._stop_consuming

    @property
    def executor(self) -> ThreadPoolExecutor:
        if self._executor is None:
            self._executor = ThreadPoolExecutor(
                max_workers=self.pool_size,
                initializer=init_worker,
            )
        return self._executor

    @property
    def cancel_client(self) -> SyncRabbitMQ:
        if self._cancel_client is None:
            self._cancel_client = SyncRabbitMQ(create_copilot_queue_config())
        return self._cancel_client

    @property
    def run_client(self) -> SyncRabbitMQ:
        if self._run_client is None:
            self._run_client = SyncRabbitMQ(create_copilot_queue_config())
        return self._run_client
@@ -1,287 +0,0 @@
|
|||||||
"""CoPilot execution processor - per-worker execution logic.
|
|
||||||
|
|
||||||
This module contains the processor class that handles CoPilot task execution
|
|
||||||
in a thread-local context, following the graph executor pattern.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
from backend.copilot import service as copilot_service
|
|
||||||
from backend.copilot import stream_registry
|
|
||||||
from backend.copilot.config import ChatConfig
|
|
||||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
|
|
||||||
from backend.copilot.sdk import service as sdk_service
|
|
||||||
from backend.executor.cluster_lock import ClusterLock
|
|
||||||
from backend.util.decorator import error_logged
|
|
||||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
|
||||||
from backend.util.logging import TruncatedLogger, configure_logging
|
|
||||||
from backend.util.process import set_service_name
|
|
||||||
from backend.util.retry import func_retry
|
|
||||||
|
|
||||||
from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
|
||||||
|
|
||||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
|
||||||
|
|
||||||
|
|
||||||
# ============ Module Entry Points ============ #
|
|
||||||
|
|
||||||
# Thread-local storage for processor instances
|
|
||||||
_tls = threading.local()
|
|
||||||
|
|
||||||
|
|
||||||
def execute_copilot_task(
|
|
||||||
entry: CoPilotExecutionEntry,
|
|
||||||
cancel: threading.Event,
|
|
||||||
cluster_lock: ClusterLock,
|
|
||||||
):
|
|
||||||
"""Execute a CoPilot task using the thread-local processor.
|
|
||||||
|
|
||||||
This function is the entry point called by the thread pool executor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
entry: The task payload
|
|
||||||
cancel: Threading event to signal cancellation
|
|
||||||
cluster_lock: Distributed lock for this execution
|
|
||||||
"""
|
|
||||||
processor: CoPilotProcessor = _tls.processor
|
|
||||||
return processor.execute(entry, cancel, cluster_lock)
|
|
||||||
|
|
||||||
|
|
||||||
def init_worker():
|
|
||||||
"""Initialize the processor for the current worker thread.
|
|
||||||
|
|
||||||
This function is called by the thread pool executor when a new worker
|
|
||||||
thread is created. It ensures each worker has its own processor instance.
|
|
||||||
"""
|
|
||||||
_tls.processor = CoPilotProcessor()
|
|
||||||
_tls.processor.on_executor_start()
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_worker():
|
|
||||||
"""Clean up the processor for the current worker thread.
|
|
||||||
|
|
||||||
Should be called before the worker thread's event loop is destroyed so
|
|
||||||
that event-loop-bound resources (e.g. ``aiohttp.ClientSession``) are
|
|
||||||
closed on the correct loop.
|
|
||||||
"""
|
|
||||||
processor: CoPilotProcessor | None = getattr(_tls, "processor", None)
|
|
||||||
if processor is not None:
|
|
||||||
processor.cleanup()
|
|
||||||
|
|
||||||
|
|
||||||
# ============ Processor Class ============ #
|
|
||||||
|
|
||||||
|
|
||||||
class CoPilotProcessor:
|
|
||||||
"""Per-worker execution logic for CoPilot tasks.
|
|
||||||
|
|
||||||
This class is instantiated once per worker thread and handles the execution
|
|
||||||
of CoPilot chat generation tasks. It maintains an async event loop for
|
|
||||||
running the async service code.
|
|
||||||
|
|
||||||
The execution flow:
|
|
||||||
1. CoPilot task is picked from RabbitMQ queue
|
|
||||||
2. Manager submits task to thread pool
|
|
||||||
3. Processor executes the task in its event loop
|
|
||||||
4. Results are published to Redis Streams
|
|
||||||
"""
|
|
||||||
|
|
||||||
@func_retry
|
|
||||||
def on_executor_start(self):
|
|
||||||
"""Initialize the processor when the worker thread starts.
|
|
||||||
|
|
||||||
This method is called once per worker thread to set up the async event
|
|
||||||
loop and initialize any required resources.
|
|
||||||
|
|
||||||
Database is accessed only through DatabaseManager, so we don't need to connect
|
|
||||||
to Prisma directly.
|
|
||||||
"""
|
|
||||||
configure_logging()
|
|
||||||
set_service_name("CoPilotExecutor")
|
|
||||||
self.tid = threading.get_ident()
|
|
||||||
self.execution_loop = asyncio.new_event_loop()
|
|
||||||
self.execution_thread = threading.Thread(
|
|
||||||
target=self.execution_loop.run_forever, daemon=True
|
|
||||||
)
|
|
||||||
self.execution_thread.start()
|
|
||||||
|
|
||||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Clean up event-loop-bound resources before the loop is destroyed.
|
|
||||||
|
|
||||||
Shuts down the workspace storage instance that belongs to this
|
|
||||||
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
|
|
||||||
runs on the same loop that created the session.
|
|
||||||
"""
|
|
||||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
|
||||||
|
|
||||||
try:
|
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
|
||||||
shutdown_workspace_storage(), self.execution_loop
|
|
||||||
)
|
|
||||||
future.result(timeout=5)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[CoPilotExecutor] Worker {self.tid} cleanup error: {e}")
|
|
||||||
|
|
||||||
# Stop the event loop
|
|
||||||
self.execution_loop.call_soon_threadsafe(self.execution_loop.stop)
|
|
||||||
self.execution_thread.join(timeout=5)
|
|
||||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} cleaned up")
|
|
||||||
|
|
||||||
@error_logged(swallow=False)
|
|
||||||
def execute(
|
|
||||||
self,
|
|
||||||
entry: CoPilotExecutionEntry,
|
|
||||||
cancel: threading.Event,
|
|
||||||
cluster_lock: ClusterLock,
|
|
||||||
):
|
|
||||||
"""Execute a CoPilot task.
|
|
||||||
|
|
||||||
This is the main entry point for task execution. It runs the async
|
|
||||||
execution logic in the worker's event loop and handles errors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
entry: The task payload containing session and message info
|
|
||||||
cancel: Threading event to signal cancellation
|
|
||||||
cluster_lock: Distributed lock to prevent duplicate execution
|
|
||||||
"""
|
|
||||||
log = CoPilotLogMetadata(
|
|
||||||
logging.getLogger(__name__),
|
|
||||||
task_id=entry.task_id,
|
|
||||||
session_id=entry.session_id,
|
|
||||||
user_id=entry.user_id,
|
|
||||||
)
|
|
||||||
log.info("Starting execution")
|
|
||||||
|
|
||||||
start_time = time.monotonic()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Run the async execution in our event loop
|
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
|
||||||
self._execute_async(entry, cancel, cluster_lock, log),
|
|
||||||
self.execution_loop,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wait for completion, checking cancel periodically
|
|
||||||
while not future.done():
|
|
||||||
try:
|
|
||||||
future.result(timeout=1.0)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
if cancel.is_set():
|
|
||||||
log.info("Cancellation requested")
|
|
||||||
future.cancel()
|
|
||||||
break
|
|
||||||
# Refresh cluster lock to maintain ownership
|
|
||||||
cluster_lock.refresh()
|
|
||||||
|
|
||||||
if not future.cancelled():
|
|
||||||
# Get result to propagate any exceptions
|
|
||||||
future.result()
|
|
||||||
|
|
||||||
elapsed = time.monotonic() - start_time
|
|
||||||
log.info(f"Execution completed in {elapsed:.2f}s")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
elapsed = time.monotonic() - start_time
|
|
||||||
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
|
|
||||||
# Note: _execute_async already marks the task as failed before re-raising,
|
|
||||||
# so we don't call _mark_task_failed here to avoid duplicate error events.
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _execute_async(
|
|
||||||
self,
|
|
||||||
entry: CoPilotExecutionEntry,
|
|
||||||
cancel: threading.Event,
|
|
||||||
cluster_lock: ClusterLock,
|
|
||||||
log: CoPilotLogMetadata,
|
|
||||||
):
|
|
||||||
"""Async execution logic for CoPilot task.
|
|
||||||
|
|
||||||
This method calls the existing stream_chat_completion service function
|
|
||||||
and publishes results to the stream registry.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
entry: The task payload
|
|
||||||
cancel: Threading event to signal cancellation
|
|
||||||
cluster_lock: Distributed lock for refresh
|
|
||||||
log: Structured logger for this task
|
|
||||||
"""
|
|
||||||
last_refresh = time.monotonic()
|
|
||||||
refresh_interval = 30.0 # Refresh lock every 30 seconds
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Choose service based on LaunchDarkly flag
|
|
||||||
config = ChatConfig()
|
|
||||||
use_sdk = await is_feature_enabled(
|
|
||||||
Flag.COPILOT_SDK,
|
|
||||||
entry.user_id or "anonymous",
|
|
||||||
default=config.use_claude_agent_sdk,
|
|
||||||
)
|
|
||||||
stream_fn = (
|
|
||||||
sdk_service.stream_chat_completion_sdk
|
|
||||||
if use_sdk
|
|
||||||
else copilot_service.stream_chat_completion
|
|
||||||
)
|
|
||||||
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
|
|
||||||
|
|
||||||
# Stream chat completion and publish chunks to Redis
|
|
||||||
async for chunk in stream_fn(
|
|
||||||
session_id=entry.session_id,
|
|
||||||
message=entry.message if entry.message else None,
|
|
||||||
is_user_message=entry.is_user_message,
|
|
||||||
user_id=entry.user_id,
|
|
||||||
context=entry.context,
|
|
||||||
):
|
|
||||||
# Check for cancellation
|
|
||||||
if cancel.is_set():
|
|
||||||
log.info("Cancelled during streaming")
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
entry.task_id, StreamError(errorText="Operation cancelled")
|
|
||||||
)
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
entry.task_id, StreamFinishStep()
|
|
||||||
)
|
|
||||||
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
|
|
||||||
await stream_registry.mark_task_completed(
|
|
||||||
entry.task_id, status="failed"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Refresh cluster lock periodically
|
|
||||||
current_time = time.monotonic()
|
|
||||||
if current_time - last_refresh >= refresh_interval:
|
|
||||||
cluster_lock.refresh()
|
|
||||||
last_refresh = current_time
|
|
||||||
|
|
||||||
# Publish chunk to stream registry
|
|
||||||
await stream_registry.publish_chunk(entry.task_id, chunk)
|
|
||||||
|
|
||||||
# Mark task as completed
|
|
||||||
await stream_registry.mark_task_completed(entry.task_id, status="completed")
|
|
||||||
log.info("Task completed successfully")
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
log.info("Task cancelled")
|
|
||||||
await stream_registry.mark_task_completed(entry.task_id, status="failed")
|
|
||||||
raise
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Task failed: {e}")
|
|
||||||
await self._mark_task_failed(entry.task_id, str(e))
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _mark_task_failed(self, task_id: str, error_message: str):
|
|
||||||
"""Mark a task as failed and publish error to stream registry."""
|
|
||||||
try:
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task_id, StreamError(errorText=error_message)
|
|
||||||
)
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
|
||||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to mark task {task_id} as failed: {e}")
|
|
||||||
@@ -1,207 +0,0 @@
"""RabbitMQ queue configuration for CoPilot executor.

Defines two exchanges and queues following the graph executor pattern:
- 'copilot_execution' (DIRECT) for chat generation tasks
- 'copilot_cancel' (FANOUT) for cancellation requests
"""

import logging

from pydantic import BaseModel

from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled

logger = logging.getLogger(__name__)


# ============ Logging Helper ============ #


class CoPilotLogMetadata(TruncatedLogger):
    """Structured logging helper for CoPilot executor.

    In cloud environments (structured logging enabled), uses a simple prefix
    and passes metadata via json_fields. In local environments, uses a detailed
    prefix with all metadata key-value pairs for easier debugging.

    Args:
        logger: The underlying logger instance
        max_length: Maximum log message length before truncation
        **kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
            These are added to json_fields in cloud mode, or to the prefix in local mode.
    """

    def __init__(
        self,
        logger: logging.Logger,
        max_length: int = 1000,
        **kwargs: str | None,
    ):
        # Filter out None values
        metadata = {k: v for k, v in kwargs.items() if v is not None}
        metadata["component"] = "CoPilotExecutor"

        if is_structured_logging_enabled():
            prefix = "[CoPilotExecutor]"
        else:
            # Build prefix from metadata key-value pairs
            meta_parts = "|".join(
                f"{k}:{v}" for k, v in metadata.items() if k != "component"
            )
            prefix = (
                f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]"
            )

        super().__init__(
            logger,
            max_length=max_length,
            prefix=prefix,
            metadata=metadata,
        )


# ============ Exchange and Queue Configuration ============ #

COPILOT_EXECUTION_EXCHANGE = Exchange(
    name="copilot_execution",
    type=ExchangeType.DIRECT,
    durable=True,
    auto_delete=False,
)
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"

COPILOT_CANCEL_EXCHANGE = Exchange(
    name="copilot_cancel",
    type=ExchangeType.FANOUT,
    durable=True,
    auto_delete=False,
)
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"

# CoPilot operations can include extended thinking and agent generation
# which may take 30+ minutes to complete
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60  # 1 hour

# Graceful shutdown timeout - allow in-flight operations to complete
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60  # 30 minutes


def create_copilot_queue_config() -> RabbitMQConfig:
    """Create RabbitMQ configuration for CoPilot executor.

    Defines two exchanges and queues:
    - 'copilot_execution' (DIRECT) for chat generation tasks
    - 'copilot_cancel' (FANOUT) for cancellation requests

    Returns:
        RabbitMQConfig with exchanges and queues defined
    """
    run_queue = Queue(
        name=COPILOT_EXECUTION_QUEUE_NAME,
        exchange=COPILOT_EXECUTION_EXCHANGE,
        routing_key=COPILOT_EXECUTION_ROUTING_KEY,
        durable=True,
        auto_delete=False,
        arguments={
            # Extended consumer timeout for long-running LLM operations
            # Default 30-minute timeout is insufficient for extended thinking
            # and agent generation which can take 30+ minutes
            "x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS * 1000,
        },
    )
    cancel_queue = Queue(
        name=COPILOT_CANCEL_QUEUE_NAME,
        exchange=COPILOT_CANCEL_EXCHANGE,
        routing_key="",  # not used for FANOUT
        durable=True,
        auto_delete=False,
    )
    return RabbitMQConfig(
        vhost="/",
        exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
        queues=[run_queue, cancel_queue],
    )
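Because the cancel exchange is FANOUT, a cancellation published once reaches every executor's cancel queue. A minimal sketch of how a publisher could use this configuration follows; `cancel_copilot_task` is a hypothetical helper (not part of this diff) that mirrors `enqueue_copilot_task` further down and reuses `CancelCoPilotEvent` from the Message Models section below.

```python
# Hypothetical helper, mirroring enqueue_copilot_task below; not part of this diff.
from backend.util.clients import get_async_copilot_queue


async def cancel_copilot_task(task_id: str) -> None:
    # CancelCoPilotEvent is defined in the Message Models section of this file.
    event = CancelCoPilotEvent(task_id=task_id)
    queue_client = await get_async_copilot_queue()
    await queue_client.publish_message(
        routing_key="",  # routing keys are ignored by FANOUT exchanges
        message=event.model_dump_json(),
        exchange=COPILOT_CANCEL_EXCHANGE,
    )
```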
# ============ Message Models ============ #


class CoPilotExecutionEntry(BaseModel):
    """Task payload for CoPilot AI generation.

    This model represents a chat generation task to be processed by the executor.
    """

    task_id: str
    """Unique identifier for this task (used for stream registry)"""

    session_id: str
    """Chat session ID"""

    user_id: str | None
    """User ID (may be None for anonymous users)"""

    operation_id: str
    """Operation ID for webhook callbacks and completion tracking"""

    message: str
    """User's message to process"""

    is_user_message: bool = True
    """Whether the message is from the user (vs system/assistant)"""

    context: dict[str, str] | None = None
    """Optional context for the message (e.g., {url: str, content: str})"""


class CancelCoPilotEvent(BaseModel):
    """Event to cancel a CoPilot operation."""

    task_id: str
    """Task ID to cancel"""


# ============ Queue Publishing Helpers ============ #


async def enqueue_copilot_task(
    task_id: str,
    session_id: str,
    user_id: str | None,
    operation_id: str,
    message: str,
    is_user_message: bool = True,
    context: dict[str, str] | None = None,
) -> None:
    """Enqueue a CoPilot task for processing by the executor service.

    Args:
        task_id: Unique identifier for this task (used for stream registry)
        session_id: Chat session ID
        user_id: User ID (may be None for anonymous users)
        operation_id: Operation ID for webhook callbacks and completion tracking
        message: User's message to process
        is_user_message: Whether the message is from the user (vs system/assistant)
        context: Optional context for the message (e.g., {url: str, content: str})
    """
    from backend.util.clients import get_async_copilot_queue

    entry = CoPilotExecutionEntry(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        operation_id=operation_id,
        message=message,
        is_user_message=is_user_message,
        context=context,
    )

    queue_client = await get_async_copilot_queue()
    await queue_client.publish_message(
        routing_key=COPILOT_EXECUTION_ROUTING_KEY,
        message=entry.model_dump_json(),
        exchange=COPILOT_EXECUTION_EXCHANGE,
    )
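A short usage sketch for the helper above; the IDs are placeholders, and in practice they come from the chat session and the stream registry rather than being hard-coded.

```python
# Placeholder IDs, purely illustrative.
await enqueue_copilot_task(
    task_id="task-0001",
    session_id="session-0001",
    user_id=None,  # anonymous user
    operation_id="op-0001",
    message="Summarize today's AI news",
)
```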
@@ -19,6 +19,30 @@ CompletedBlockOutput = dict[str, list[Any]]  # Completed stream, collected as a


async def initialize_blocks() -> None:
    # Refresh LLM registry before initializing blocks so blocks can use registry data
    # This ensures the registry cache is populated even in executor context
    try:
        from backend.data import llm_registry
        from backend.data.block_cost_config import refresh_llm_costs

        # Only refresh if we have DB access (check if Prisma is connected)
        from backend.data.db import is_connected

        if is_connected():
            await llm_registry.refresh_llm_registry()
            await refresh_llm_costs()
            logger.info("LLM registry refreshed during block initialization")
        else:
            logger.warning(
                "Prisma not connected, skipping LLM registry refresh during block initialization"
            )
    except Exception as exc:
        logger.warning(
            "Failed to refresh LLM registry during block initialization: %s", exc
        )

    # First, sync all provider costs to blocks
    # Imported here to avoid circular import
    from backend.blocks import get_blocks
    from backend.sdk.cost_integration import sync_all_provider_costs
    from backend.util.retry import func_retry
@@ -1,5 +1,8 @@
import logging
from typing import Type

import prisma.models

from backend.blocks._base import Block, BlockCost, BlockCostType
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
from backend.blocks.ai_image_generator_block import AIImageGeneratorBlock, ImageGenModel
@@ -24,13 +27,11 @@ from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
from backend.blocks.llm import (
    MODEL_METADATA,
    AIConversationBlock,
    AIListGeneratorBlock,
    AIStructuredResponseGeneratorBlock,
    AITextGeneratorBlock,
    AITextSummarizerBlock,
    LlmModel,
)
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
@@ -38,6 +39,7 @@ from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.data import llm_registry
from backend.integrations.credentials_store import (
    aiml_api_credentials,
    anthropic_credentials,
@@ -57,210 +59,116 @@ from backend.integrations.credentials_store import (
|
|||||||
v0_credentials,
|
v0_credentials,
|
||||||
)
|
)
|
||||||
|
|
||||||
# =============== Configure the cost for each LLM Model call =============== #
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MODEL_COST: dict[LlmModel, int] = {
|
PROVIDER_CREDENTIALS = {
|
||||||
LlmModel.O3: 4,
|
"openai": openai_credentials,
|
||||||
LlmModel.O3_MINI: 2,
|
"anthropic": anthropic_credentials,
|
||||||
LlmModel.O1: 16,
|
"groq": groq_credentials,
|
||||||
LlmModel.O1_MINI: 4,
|
"open_router": open_router_credentials,
|
||||||
# GPT-5 models
|
"llama_api": llama_api_credentials,
|
||||||
LlmModel.GPT5_2: 6,
|
"aiml_api": aiml_api_credentials,
|
||||||
LlmModel.GPT5_1: 5,
|
"v0": v0_credentials,
|
||||||
LlmModel.GPT5: 2,
|
|
||||||
LlmModel.GPT5_MINI: 1,
|
|
||||||
LlmModel.GPT5_NANO: 1,
|
|
||||||
LlmModel.GPT5_CHAT: 5,
|
|
||||||
LlmModel.GPT41: 2,
|
|
||||||
LlmModel.GPT41_MINI: 1,
|
|
||||||
LlmModel.GPT4O_MINI: 1,
|
|
||||||
LlmModel.GPT4O: 3,
|
|
||||||
LlmModel.GPT4_TURBO: 10,
|
|
||||||
LlmModel.GPT3_5_TURBO: 1,
|
|
||||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
|
||||||
LlmModel.CLAUDE_4_OPUS: 21,
|
|
||||||
LlmModel.CLAUDE_4_SONNET: 5,
|
|
||||||
LlmModel.CLAUDE_4_6_OPUS: 14,
|
|
||||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
|
||||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
|
||||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
|
||||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
|
||||||
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
|
||||||
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
|
||||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
|
||||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
|
||||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
|
||||||
LlmModel.LLAMA3_3_70B: 1,
|
|
||||||
LlmModel.LLAMA3_1_8B: 1,
|
|
||||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
|
||||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
|
||||||
LlmModel.OLLAMA_LLAMA3_8B: 1,
|
|
||||||
LlmModel.OLLAMA_LLAMA3_405B: 1,
|
|
||||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
|
||||||
LlmModel.OPENAI_GPT_OSS_120B: 1,
|
|
||||||
LlmModel.OPENAI_GPT_OSS_20B: 1,
|
|
||||||
LlmModel.GEMINI_2_5_PRO: 4,
|
|
||||||
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
|
|
||||||
LlmModel.GEMINI_2_5_FLASH: 1,
|
|
||||||
LlmModel.GEMINI_2_0_FLASH: 1,
|
|
||||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
|
|
||||||
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
|
|
||||||
LlmModel.MISTRAL_NEMO: 1,
|
|
||||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
|
||||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
|
|
||||||
LlmModel.DEEPSEEK_CHAT: 2,
|
|
||||||
LlmModel.DEEPSEEK_R1_0528: 1,
|
|
||||||
LlmModel.PERPLEXITY_SONAR: 1,
|
|
||||||
LlmModel.PERPLEXITY_SONAR_PRO: 5,
|
|
||||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
|
|
||||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
|
|
||||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
|
|
||||||
LlmModel.AMAZON_NOVA_LITE_V1: 1,
|
|
||||||
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
|
|
||||||
LlmModel.AMAZON_NOVA_PRO_V1: 1,
|
|
||||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
|
|
||||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
|
|
||||||
LlmModel.META_LLAMA_4_SCOUT: 1,
|
|
||||||
LlmModel.META_LLAMA_4_MAVERICK: 1,
|
|
||||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
|
|
||||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
|
|
||||||
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
|
|
||||||
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
|
|
||||||
LlmModel.GROK_4: 9,
|
|
||||||
LlmModel.GROK_4_FAST: 1,
|
|
||||||
LlmModel.GROK_4_1_FAST: 1,
|
|
||||||
LlmModel.GROK_CODE_FAST_1: 1,
|
|
||||||
LlmModel.KIMI_K2: 1,
|
|
||||||
LlmModel.QWEN3_235B_A22B_THINKING: 1,
|
|
||||||
LlmModel.QWEN3_CODER: 9,
|
|
||||||
# v0 by Vercel models
|
|
||||||
LlmModel.V0_1_5_MD: 1,
|
|
||||||
LlmModel.V0_1_5_LG: 2,
|
|
||||||
LlmModel.V0_1_0_MD: 1,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for model in LlmModel:
|
# =============== Configure the cost for each LLM Model call =============== #
|
||||||
if model not in MODEL_COST:
|
# All LLM costs now come from the database via llm_registry
|
||||||
raise ValueError(f"Missing MODEL_COST for model: {model}")
|
|
||||||
|
LLM_COST: list[BlockCost] = []
|
||||||
|
|
||||||
|
|
||||||
LLM_COST = (
|
async def _build_llm_costs_from_registry() -> list[BlockCost]:
|
||||||
# Anthropic Models
|
"""
|
||||||
[
|
Build BlockCost list from all models in the LLM registry.
|
||||||
BlockCost(
|
|
||||||
cost_type=BlockCostType.RUN,
|
This function checks for active model migrations with customCreditCost overrides.
|
||||||
cost_filter={
|
When a model has been migrated with a custom price, that price is used instead
|
||||||
"model": model,
|
of the target model's default cost.
|
||||||
|
"""
|
||||||
|
# Query active migrations with custom pricing overrides.
|
||||||
|
# Note: LlmModelMigration is system-level data (no userId field) and this function
|
||||||
|
# is only called during app startup and admin operations, so no user ID filter needed.
|
||||||
|
migration_overrides: dict[str, int] = {}
|
||||||
|
try:
|
||||||
|
active_migrations = await prisma.models.LlmModelMigration.prisma().find_many(
|
||||||
|
where={
|
||||||
|
"isReverted": False,
|
||||||
|
"customCreditCost": {"not": None},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Key by targetModelSlug since that's the model nodes are now using
|
||||||
|
# after migration. The custom cost applies to the target model.
|
||||||
|
migration_overrides = {
|
||||||
|
migration.targetModelSlug: migration.customCreditCost
|
||||||
|
for migration in active_migrations
|
||||||
|
if migration.customCreditCost is not None
|
||||||
|
}
|
||||||
|
if migration_overrides:
|
||||||
|
logger.info(
|
||||||
|
"Found %d active model migrations with custom pricing overrides",
|
||||||
|
len(migration_overrides),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to query model migration overrides: %s. Proceeding with default costs.",
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
costs: list[BlockCost] = []
|
||||||
|
for model in llm_registry.iter_dynamic_models():
|
||||||
|
for cost in model.costs:
|
||||||
|
credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
|
||||||
|
if not credentials:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping cost entry for %s due to unknown credentials provider %s",
|
||||||
|
model.slug,
|
||||||
|
cost.credential_provider,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if this model has a custom cost override from migration
|
||||||
|
cost_amount = migration_overrides.get(model.slug, cost.credit_cost)
|
||||||
|
|
||||||
|
if model.slug in migration_overrides:
|
||||||
|
logger.debug(
|
||||||
|
"Applying custom cost override for model %s: %d credits (default: %d)",
|
||||||
|
model.slug,
|
||||||
|
cost_amount,
|
||||||
|
cost.credit_cost,
|
||||||
|
)
|
||||||
|
|
||||||
|
cost_filter = {
|
||||||
|
"model": model.slug,
|
||||||
"credentials": {
|
"credentials": {
|
||||||
"id": anthropic_credentials.id,
|
"id": credentials.id,
|
||||||
"provider": anthropic_credentials.provider,
|
"provider": credentials.provider,
|
||||||
"type": anthropic_credentials.type,
|
"type": credentials.type,
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
cost_amount=cost,
|
costs.append(
|
||||||
)
|
BlockCost(
|
||||||
for model, cost in MODEL_COST.items()
|
cost_type=BlockCostType.RUN,
|
||||||
if MODEL_METADATA[model].provider == "anthropic"
|
cost_filter=cost_filter,
|
||||||
]
|
cost_amount=cost_amount,
|
||||||
# OpenAI Models
|
)
|
||||||
+ [
|
)
|
||||||
BlockCost(
|
return costs
|
||||||
cost_type=BlockCostType.RUN,
|
|
||||||
cost_filter={
|
|
||||||
"model": model,
|
async def refresh_llm_costs() -> None:
|
||||||
"credentials": {
|
"""
|
||||||
"id": openai_credentials.id,
|
Refresh LLM costs from the registry. All costs now come from the database.
|
||||||
"provider": openai_credentials.provider,
|
|
||||||
"type": openai_credentials.type,
|
This function also checks for active model migrations with custom pricing overrides
|
||||||
},
|
and applies them to ensure accurate billing.
|
||||||
},
|
"""
|
||||||
cost_amount=cost,
|
LLM_COST.clear()
|
||||||
)
|
LLM_COST.extend(await _build_llm_costs_from_registry())
|
||||||
for model, cost in MODEL_COST.items()
|
|
||||||
if MODEL_METADATA[model].provider == "openai"
|
|
||||||
]
|
# Initial load will happen after registry is refreshed at startup
|
||||||
# Groq Models
|
# Don't call refresh_llm_costs() here - it will be called after registry refresh
|
||||||
+ [
|
|
||||||
BlockCost(
|
|
||||||
cost_type=BlockCostType.RUN,
|
|
||||||
cost_filter={
|
|
||||||
"model": model,
|
|
||||||
"credentials": {"id": groq_credentials.id},
|
|
||||||
},
|
|
||||||
cost_amount=cost,
|
|
||||||
)
|
|
||||||
for model, cost in MODEL_COST.items()
|
|
||||||
if MODEL_METADATA[model].provider == "groq"
|
|
||||||
]
|
|
||||||
# Open Router Models
|
|
||||||
+ [
|
|
||||||
BlockCost(
|
|
||||||
cost_type=BlockCostType.RUN,
|
|
||||||
cost_filter={
|
|
||||||
"model": model,
|
|
||||||
"credentials": {
|
|
||||||
"id": open_router_credentials.id,
|
|
||||||
"provider": open_router_credentials.provider,
|
|
||||||
"type": open_router_credentials.type,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
cost_amount=cost,
|
|
||||||
)
|
|
||||||
for model, cost in MODEL_COST.items()
|
|
||||||
if MODEL_METADATA[model].provider == "open_router"
|
|
||||||
]
|
|
||||||
# Llama API Models
|
|
||||||
+ [
|
|
||||||
BlockCost(
|
|
||||||
cost_type=BlockCostType.RUN,
|
|
||||||
cost_filter={
|
|
||||||
"model": model,
|
|
||||||
"credentials": {
|
|
||||||
"id": llama_api_credentials.id,
|
|
||||||
"provider": llama_api_credentials.provider,
|
|
||||||
"type": llama_api_credentials.type,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
cost_amount=cost,
|
|
||||||
)
|
|
||||||
for model, cost in MODEL_COST.items()
|
|
||||||
if MODEL_METADATA[model].provider == "llama_api"
|
|
||||||
]
|
|
||||||
# v0 by Vercel Models
|
|
||||||
+ [
|
|
||||||
BlockCost(
|
|
||||||
cost_type=BlockCostType.RUN,
|
|
||||||
cost_filter={
|
|
||||||
"model": model,
|
|
||||||
"credentials": {
|
|
||||||
"id": v0_credentials.id,
|
|
||||||
"provider": v0_credentials.provider,
|
|
||||||
"type": v0_credentials.type,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
cost_amount=cost,
|
|
||||||
)
|
|
||||||
for model, cost in MODEL_COST.items()
|
|
||||||
if MODEL_METADATA[model].provider == "v0"
|
|
||||||
]
|
|
||||||
# AI/ML Api Models
|
|
||||||
+ [
|
|
||||||
BlockCost(
|
|
||||||
cost_type=BlockCostType.RUN,
|
|
||||||
cost_filter={
|
|
||||||
"model": model,
|
|
||||||
"credentials": {
|
|
||||||
"id": aiml_api_credentials.id,
|
|
||||||
"provider": aiml_api_credentials.provider,
|
|
||||||
"type": aiml_api_credentials.type,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
cost_amount=cost,
|
|
||||||
)
|
|
||||||
for model, cost in MODEL_COST.items()
|
|
||||||
if MODEL_METADATA[model].provider == "aiml_api"
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# =============== This is the exhaustive list of cost for each Block =============== #
|
# =============== This is the exhaustive list of cost for each Block =============== #
|
||||||
|
|
||||||
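The comments at the end of this file note that `refresh_llm_costs()` is not called at import time and that the initial load happens after the registry is refreshed at startup. A minimal sketch of that ordering, assuming an async startup hook; the wrapper function name is illustrative only.

```python
# Sketch of the intended startup ordering described in the comments above;
# load_llm_pricing is an illustrative name, not part of this diff.
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs


async def load_llm_pricing() -> None:
    await llm_registry.refresh_llm_registry()  # load models and costs from the DB
    await refresh_llm_costs()  # rebuild LLM_COST from the refreshed registry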
|
|||||||
@@ -1,118 +0,0 @@
|
|||||||
from backend.data import db
|
|
||||||
|
|
||||||
|
|
||||||
def chat_db():
|
|
||||||
if db.is_connected():
|
|
||||||
from backend.copilot import db as _chat_db
|
|
||||||
|
|
||||||
chat_db = _chat_db
|
|
||||||
else:
|
|
||||||
from backend.util.clients import get_database_manager_async_client
|
|
||||||
|
|
||||||
chat_db = get_database_manager_async_client()
|
|
||||||
|
|
||||||
return chat_db
|
|
||||||
|
|
||||||
|
|
||||||
def graph_db():
|
|
||||||
if db.is_connected():
|
|
||||||
from backend.data import graph as _graph_db
|
|
||||||
|
|
||||||
graph_db = _graph_db
|
|
||||||
else:
|
|
||||||
from backend.util.clients import get_database_manager_async_client
|
|
||||||
|
|
||||||
graph_db = get_database_manager_async_client()
|
|
||||||
|
|
||||||
return graph_db
|
|
||||||
|
|
||||||
|
|
||||||
def library_db():
|
|
||||||
if db.is_connected():
|
|
||||||
from backend.api.features.library import db as _library_db
|
|
||||||
|
|
||||||
library_db = _library_db
|
|
||||||
else:
|
|
||||||
from backend.util.clients import get_database_manager_async_client
|
|
||||||
|
|
||||||
library_db = get_database_manager_async_client()
|
|
||||||
|
|
||||||
return library_db
|
|
||||||
|
|
||||||
|
|
||||||
def store_db():
|
|
||||||
if db.is_connected():
|
|
||||||
from backend.api.features.store import db as _store_db
|
|
||||||
|
|
||||||
store_db = _store_db
|
|
||||||
else:
|
|
||||||
from backend.util.clients import get_database_manager_async_client
|
|
||||||
|
|
||||||
store_db = get_database_manager_async_client()
|
|
||||||
|
|
||||||
return store_db
|
|
||||||
|
|
||||||
|
|
||||||
def search():
|
|
||||||
if db.is_connected():
|
|
||||||
from backend.api.features.store import hybrid_search as _search
|
|
||||||
|
|
||||||
search = _search
|
|
||||||
else:
|
|
||||||
from backend.util.clients import get_database_manager_async_client
|
|
||||||
|
|
||||||
search = get_database_manager_async_client()
|
|
||||||
|
|
||||||
return search
|
|
||||||
|
|
||||||
|
|
||||||
def execution_db():
|
|
||||||
if db.is_connected():
|
|
||||||
from backend.data import execution as _execution_db
|
|
||||||
|
|
||||||
execution_db = _execution_db
|
|
||||||
else:
|
|
||||||
from backend.util.clients import get_database_manager_async_client
|
|
||||||
|
|
||||||
execution_db = get_database_manager_async_client()
|
|
||||||
|
|
||||||
return execution_db
|
|
||||||
|
|
||||||
|
|
||||||
def user_db():
|
|
||||||
if db.is_connected():
|
|
||||||
from backend.data import user as _user_db
|
|
||||||
|
|
||||||
user_db = _user_db
|
|
||||||
else:
|
|
||||||
from backend.util.clients import get_database_manager_async_client
|
|
||||||
|
|
||||||
user_db = get_database_manager_async_client()
|
|
||||||
|
|
||||||
return user_db
|
|
||||||
|
|
||||||
|
|
||||||
def understanding_db():
|
|
||||||
if db.is_connected():
|
|
||||||
from backend.data import understanding as _understanding_db
|
|
||||||
|
|
||||||
understanding_db = _understanding_db
|
|
||||||
else:
|
|
||||||
from backend.util.clients import get_database_manager_async_client
|
|
||||||
|
|
||||||
understanding_db = get_database_manager_async_client()
|
|
||||||
|
|
||||||
return understanding_db
|
|
||||||
|
|
||||||
|
|
||||||
def workspace_db():
|
|
||||||
if db.is_connected():
|
|
||||||
from backend.data import workspace as _workspace_db
|
|
||||||
|
|
||||||
workspace_db = _workspace_db
|
|
||||||
else:
|
|
||||||
from backend.util.clients import get_database_manager_async_client
|
|
||||||
|
|
||||||
workspace_db = get_database_manager_async_client()
|
|
||||||
|
|
||||||
return workspace_db
|
|
||||||
@@ -1147,14 +1147,14 @@ async def get_graph(
    return GraphModel.from_db(graph, for_export)


async def get_store_listed_graphs(graph_ids: list[str]) -> dict[str, GraphModel]:
async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
    """Batch-fetch multiple store-listed graphs by their IDs.

    Only returns graphs that have approved store listings (publicly available).
    Does not require permission checks since store-listed graphs are public.

    Args:
        graph_ids: List of graph IDs to fetch
        *graph_ids: Variable number of graph IDs to fetch

    Returns:
        Dict mapping graph_id to GraphModel for graphs with approved store listings
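The hunk above changes the signature from a list parameter to varargs, so callers now pass IDs directly. A usage sketch with placeholder graph IDs:

```python
# Placeholder IDs; only graphs with approved store listings come back.
graphs = await get_store_listed_graphs("graph-id-1", "graph-id-2")
first = graphs.get("graph-id-1")  # None if that graph has no approved listing
```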
@@ -1663,8 +1663,10 @@ async def migrate_llm_models(migrate_to: LlmModel):
        if field.annotation == LlmModel:
            llm_model_fields[block.id] = field_name

    # Convert enum values to a list of strings for the SQL query
    # Get all model slugs from the registry (dynamic, not hardcoded enum)
    enum_values = [v.value for v in LlmModel]
    from backend.data import llm_registry

    enum_values = list(llm_registry.get_all_model_slugs_for_validation())
    escaped_enum_values = repr(tuple(enum_values))  # hack but works

    # Update each block
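To make the "hack but works" line concrete, here is what `repr(tuple(...))` produces; the slugs are made up for illustration.

```python
# Made-up slugs, purely to show the resulting string:
enum_values = ["gpt-4o", "claude-3-5-sonnet"]
escaped_enum_values = repr(tuple(enum_values))
# -> "('gpt-4o', 'claude-3-5-sonnet')", which can be interpolated into a SQL IN (...) clause
```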
@@ -0,0 +1,72 @@
|
|||||||
|
"""
|
||||||
|
LLM Registry module for managing LLM models, providers, and costs dynamically.
|
||||||
|
|
||||||
|
This module provides a database-driven registry system for LLM models,
|
||||||
|
replacing hardcoded model configurations with a flexible admin-managed system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from backend.data.llm_registry.model import ModelMetadata
|
||||||
|
|
||||||
|
# Re-export for backwards compatibility
|
||||||
|
from backend.data.llm_registry.notifications import (
|
||||||
|
REGISTRY_REFRESH_CHANNEL,
|
||||||
|
publish_registry_refresh_notification,
|
||||||
|
subscribe_to_registry_refresh,
|
||||||
|
)
|
||||||
|
from backend.data.llm_registry.registry import (
|
||||||
|
RegistryModel,
|
||||||
|
RegistryModelCost,
|
||||||
|
RegistryModelCreator,
|
||||||
|
get_all_model_slugs_for_validation,
|
||||||
|
get_default_model_slug,
|
||||||
|
get_dynamic_model_slugs,
|
||||||
|
get_fallback_model_for_disabled,
|
||||||
|
get_llm_discriminator_mapping,
|
||||||
|
get_llm_model_cost,
|
||||||
|
get_llm_model_metadata,
|
||||||
|
get_llm_model_schema_options,
|
||||||
|
get_model_info,
|
||||||
|
is_model_enabled,
|
||||||
|
iter_dynamic_models,
|
||||||
|
refresh_llm_registry,
|
||||||
|
register_static_costs,
|
||||||
|
register_static_metadata,
|
||||||
|
)
|
||||||
|
from backend.data.llm_registry.schema_utils import (
|
||||||
|
is_llm_model_field,
|
||||||
|
refresh_llm_discriminator_mapping,
|
||||||
|
refresh_llm_model_options,
|
||||||
|
update_schema_with_llm_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Types
|
||||||
|
"ModelMetadata",
|
||||||
|
"RegistryModel",
|
||||||
|
"RegistryModelCost",
|
||||||
|
"RegistryModelCreator",
|
||||||
|
# Registry functions
|
||||||
|
"get_all_model_slugs_for_validation",
|
||||||
|
"get_default_model_slug",
|
||||||
|
"get_dynamic_model_slugs",
|
||||||
|
"get_fallback_model_for_disabled",
|
||||||
|
"get_llm_discriminator_mapping",
|
||||||
|
"get_llm_model_cost",
|
||||||
|
"get_llm_model_metadata",
|
||||||
|
"get_llm_model_schema_options",
|
||||||
|
"get_model_info",
|
||||||
|
"is_model_enabled",
|
||||||
|
"iter_dynamic_models",
|
||||||
|
"refresh_llm_registry",
|
||||||
|
"register_static_costs",
|
||||||
|
"register_static_metadata",
|
||||||
|
# Notifications
|
||||||
|
"REGISTRY_REFRESH_CHANNEL",
|
||||||
|
"publish_registry_refresh_notification",
|
||||||
|
"subscribe_to_registry_refresh",
|
||||||
|
# Schema utilities
|
||||||
|
"is_llm_model_field",
|
||||||
|
"refresh_llm_discriminator_mapping",
|
||||||
|
"refresh_llm_model_options",
|
||||||
|
"update_schema_with_llm_registry",
|
||||||
|
]
|
||||||
autogpt_platform/backend/backend/data/llm_registry/model.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""Type definitions for LLM model metadata."""

from typing import Literal, NamedTuple


class ModelMetadata(NamedTuple):
    """Metadata for an LLM model.

    Attributes:
        provider: The provider identifier (e.g., "openai", "anthropic")
        context_window: Maximum context window size in tokens
        max_output_tokens: Maximum output tokens (None if unlimited)
        display_name: Human-readable name for the model
        provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
        creator_name: Name of the organization that created the model
        price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
    """

    provider: str
    context_window: int
    max_output_tokens: int | None
    display_name: str
    provider_name: str
    creator_name: str
    price_tier: Literal[1, 2, 3]
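An example instantiation of the tuple above; the values are illustrative only, since real entries are built from the LlmModel table by the registry.

```python
# Illustrative values only; real entries are loaded from the database.
example = ModelMetadata(
    provider="openai",
    context_window=128_000,
    max_output_tokens=16_384,
    display_name="GPT-4o mini",
    provider_name="OpenAI",
    creator_name="OpenAI",
    price_tier=1,
)
```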
@@ -0,0 +1,89 @@
"""
Redis pub/sub notifications for LLM registry updates.

When models are added/updated/removed via the admin UI, this module
publishes notifications to Redis that all executor services subscribe to,
ensuring they refresh their registry cache in real-time.
"""

import asyncio
import logging
from typing import Any

from backend.data.redis_client import connect_async

logger = logging.getLogger(__name__)

# Redis channel name for LLM registry refresh notifications
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"


async def publish_registry_refresh_notification() -> None:
    """
    Publish a notification to Redis that the LLM registry has been updated.
    All executor services subscribed to this channel will refresh their registry.
    """
    try:
        redis = await connect_async()
        await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
        logger.info("Published LLM registry refresh notification to Redis")
    except Exception as exc:
        logger.warning(
            "Failed to publish LLM registry refresh notification: %s",
            exc,
            exc_info=True,
        )


async def subscribe_to_registry_refresh(
    on_refresh: Any,  # Async callable that takes no args
) -> None:
    """
    Subscribe to Redis notifications for LLM registry updates.
    This runs in a loop and processes messages as they arrive.

    Args:
        on_refresh: Async callable to execute when a refresh notification is received
    """
    try:
        redis = await connect_async()
        pubsub = redis.pubsub()
        await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
        logger.info(
            "Subscribed to LLM registry refresh notifications on channel: %s",
            REGISTRY_REFRESH_CHANNEL,
        )

        # Process messages in a loop
        while True:
            try:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=1.0
                )
                if (
                    message
                    and message["type"] == "message"
                    and message["channel"] == REGISTRY_REFRESH_CHANNEL
                ):
                    logger.info("Received LLM registry refresh notification")
                    try:
                        await on_refresh()
                    except Exception as exc:
                        logger.error(
                            "Error refreshing LLM registry from notification: %s",
                            exc,
                            exc_info=True,
                        )
            except Exception as exc:
                logger.warning(
                    "Error processing registry refresh message: %s", exc, exc_info=True
                )
                # Continue listening even if one message fails
                await asyncio.sleep(1)
    except Exception as exc:
        logger.error(
            "Failed to subscribe to LLM registry refresh notifications: %s",
            exc,
            exc_info=True,
        )
        raise
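A minimal startup sketch showing how an executor service could wire this subscriber to the registry refresh, assuming an asyncio-based service; where exactly this hook lives is an assumption, not part of this diff. Both functions are re-exported from `backend.data.llm_registry`.

```python
# Startup wiring sketch; start_registry_sync is an illustrative name.
import asyncio

from backend.data import llm_registry


async def start_registry_sync() -> None:
    await llm_registry.refresh_llm_registry()  # initial load from the database
    asyncio.create_task(  # then react to admin-triggered refresh notifications
        llm_registry.subscribe_to_registry_refresh(llm_registry.refresh_llm_registry)
    )
```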
autogpt_platform/backend/backend/data/llm_registry/registry.py (new file, 388 lines)
@@ -0,0 +1,388 @@
|
|||||||
|
"""Core LLM registry implementation for managing models dynamically."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Iterable
|
||||||
|
|
||||||
|
import prisma.models
|
||||||
|
|
||||||
|
from backend.data.llm_registry.model import ModelMetadata
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _json_to_dict(value: Any) -> dict[str, Any]:
|
||||||
|
"""Convert Prisma Json type to dict, with fallback to empty dict."""
|
||||||
|
if value is None:
|
||||||
|
return {}
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return value
|
||||||
|
# Prisma Json type should always be a dict at runtime
|
||||||
|
return dict(value) if value else {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RegistryModelCost:
|
||||||
|
"""Cost configuration for an LLM model."""
|
||||||
|
|
||||||
|
credit_cost: int
|
||||||
|
credential_provider: str
|
||||||
|
credential_id: str | None
|
||||||
|
credential_type: str | None
|
||||||
|
currency: str | None
|
||||||
|
metadata: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RegistryModelCreator:
|
||||||
|
"""Creator information for an LLM model."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
display_name: str
|
||||||
|
description: str | None
|
||||||
|
website_url: str | None
|
||||||
|
logo_url: str | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RegistryModel:
|
||||||
|
"""Represents a model in the LLM registry."""
|
||||||
|
|
||||||
|
slug: str
|
||||||
|
display_name: str
|
||||||
|
description: str | None
|
||||||
|
metadata: ModelMetadata
|
||||||
|
capabilities: dict[str, Any]
|
||||||
|
extra_metadata: dict[str, Any]
|
||||||
|
provider_display_name: str
|
||||||
|
is_enabled: bool
|
||||||
|
is_recommended: bool = False
|
||||||
|
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
|
||||||
|
creator: RegistryModelCreator | None = None
|
||||||
|
|
||||||
|
|
||||||
|
_static_metadata: dict[str, ModelMetadata] = {}
|
||||||
|
_static_costs: dict[str, int] = {}
|
||||||
|
_dynamic_models: dict[str, RegistryModel] = {}
|
||||||
|
_schema_options: list[dict[str, str]] = []
|
||||||
|
_discriminator_mapping: dict[str, str] = {}
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
|
||||||
|
"""Register static metadata for legacy models (deprecated)."""
|
||||||
|
_static_metadata.update({str(key): value for key, value in metadata.items()})
|
||||||
|
_refresh_cached_schema()
|
||||||
|
|
||||||
|
|
||||||
|
def register_static_costs(costs: dict[Any, int]) -> None:
|
||||||
|
"""Register static costs for legacy models (deprecated)."""
|
||||||
|
_static_costs.update({str(key): value for key, value in costs.items()})
|
||||||
|
|
||||||
|
|
||||||
|
def _build_schema_options() -> list[dict[str, str]]:
|
||||||
|
"""Build schema options for model selection dropdown. Only includes enabled models."""
|
||||||
|
options: list[dict[str, str]] = []
|
||||||
|
# Only include enabled models in the dropdown options
|
||||||
|
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
|
||||||
|
if model.is_enabled:
|
||||||
|
options.append(
|
||||||
|
{
|
||||||
|
"label": model.display_name,
|
||||||
|
"value": model.slug,
|
||||||
|
"group": model.metadata.provider,
|
||||||
|
"description": model.description or "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for slug, metadata in _static_metadata.items():
|
||||||
|
if slug in _dynamic_models:
|
||||||
|
continue
|
||||||
|
options.append(
|
||||||
|
{
|
||||||
|
"label": slug,
|
||||||
|
"value": slug,
|
||||||
|
"group": metadata.provider,
|
||||||
|
"description": "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_llm_registry() -> None:
    """Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
    async with _lock:
        try:
            records = await prisma.models.LlmModel.prisma().find_many(
                include={
                    "Provider": True,
                    "Costs": True,
                    "Creator": True,
                }
            )
            logger.debug("Found %d LLM model records in database", len(records))
        except Exception as exc:
            logger.error(
                "Failed to refresh LLM registry from DB: %s", exc, exc_info=True
            )
            return

        dynamic: dict[str, RegistryModel] = {}
        for record in records:
            provider_name = (
                record.Provider.name if record.Provider else record.providerId
            )
            provider_display_name = (
                record.Provider.displayName if record.Provider else record.providerId
            )
            # Creator name: prefer Creator.name, fallback to provider display name
            creator_name = (
                record.Creator.name if record.Creator else provider_display_name
            )
            # Price tier: default to 1 (cheapest) if not set
            price_tier = getattr(record, "priceTier", 1) or 1
            # Clamp to valid range 1-3
            price_tier = max(1, min(3, price_tier))

            metadata = ModelMetadata(
                provider=provider_name,
                context_window=record.contextWindow,
                max_output_tokens=record.maxOutputTokens,
                display_name=record.displayName,
                provider_name=provider_display_name,
                creator_name=creator_name,
                price_tier=price_tier,  # type: ignore[arg-type]
            )
            costs = tuple(
                RegistryModelCost(
                    credit_cost=cost.creditCost,
                    credential_provider=cost.credentialProvider,
                    credential_id=cost.credentialId,
                    credential_type=cost.credentialType,
                    currency=cost.currency,
                    metadata=_json_to_dict(cost.metadata),
                )
                for cost in (record.Costs or [])
            )

            # Map creator if present
            creator = None
            if record.Creator:
                creator = RegistryModelCreator(
                    id=record.Creator.id,
                    name=record.Creator.name,
                    display_name=record.Creator.displayName,
                    description=record.Creator.description,
                    website_url=record.Creator.websiteUrl,
                    logo_url=record.Creator.logoUrl,
                )

            dynamic[record.slug] = RegistryModel(
                slug=record.slug,
                display_name=record.displayName,
                description=record.description,
                metadata=metadata,
                capabilities=_json_to_dict(record.capabilities),
                extra_metadata=_json_to_dict(record.metadata),
                provider_display_name=(
                    record.Provider.displayName
                    if record.Provider
                    else record.providerId
                ),
                is_enabled=record.isEnabled,
                is_recommended=record.isRecommended,
                costs=costs,
                creator=creator,
            )

        # Atomic swap - build new structures then replace references
        # This ensures readers never see partially updated state
        global _dynamic_models
        _dynamic_models = dynamic
        _refresh_cached_schema()
        logger.info(
            "LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
            len(dynamic),
            sum(1 for m in dynamic.values() if m.is_enabled),
            sum(1 for m in dynamic.values() if not m.is_enabled),
        )


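The "build fully, then swap the reference" comment above is the crux of the refresh. As a minimal sketch (not from this PR; the `_registry` dict, `refresh` signature, and sample rows are invented stand-ins), this is why readers that only ever dereference the module-level name cannot observe a half-built registry:

```python
# Minimal sketch of the build-then-swap pattern used by refresh_llm_registry().
# Names and data here are hypothetical, for illustration only.
import asyncio

_registry: dict[str, str] = {}  # module-level state that readers consult
_lock = asyncio.Lock()


async def refresh(rows: list[tuple[str, str]]) -> None:
    async with _lock:
        fresh: dict[str, str] = {}
        for slug, provider in rows:
            fresh[slug] = provider  # build the complete mapping first
        global _registry
        _registry = fresh  # single reference swap; readers never see a partial dict


async def main() -> None:
    await refresh([("gpt-4o", "openai"), ("claude-3-5-sonnet", "anthropic")])
    print(_registry)


asyncio.run(main())
```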
def _refresh_cached_schema() -> None:
    """Refresh cached schema options and discriminator mapping."""
    global _schema_options, _discriminator_mapping

    # Build new structures
    new_options = _build_schema_options()
    new_mapping = {
        slug: entry.metadata.provider for slug, entry in _dynamic_models.items()
    }
    for slug, metadata in _static_metadata.items():
        new_mapping.setdefault(slug, metadata.provider)

    # Atomic swap - replace references to ensure readers see consistent state
    _schema_options = new_options
    _discriminator_mapping = new_mapping


def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
    """Get model metadata by slug. Checks dynamic models first, then static metadata."""
    if slug in _dynamic_models:
        return _dynamic_models[slug].metadata
    return _static_metadata.get(slug)


def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
    """Get model cost configuration by slug."""
    if slug in _dynamic_models:
        return _dynamic_models[slug].costs
    cost_value = _static_costs.get(slug)
    if cost_value is None:
        return tuple()
    return (
        RegistryModelCost(
            credit_cost=cost_value,
            credential_provider="static",
            credential_id=None,
            credential_type=None,
            currency=None,
            metadata={},
        ),
    )


def get_llm_model_schema_options() -> list[dict[str, str]]:
    """
    Get schema options for LLM model selection dropdown.

    Returns a copy of cached schema options that are refreshed when the registry is
    updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
    """
    # Return a copy to prevent external mutation
    return list(_schema_options)


def get_llm_discriminator_mapping() -> dict[str, str]:
    """
    Get discriminator mapping for LLM models.

    Returns a copy of cached discriminator mapping that is refreshed when the registry
    is updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
    """
    # Return a copy to prevent external mutation
    return dict(_discriminator_mapping)


def get_dynamic_model_slugs() -> set[str]:
    """Get all dynamic model slugs from the registry."""
    return set(_dynamic_models.keys())


def get_all_model_slugs_for_validation() -> set[str]:
    """
    Get ALL model slugs (both enabled and disabled) for validation purposes.

    This is used for JSON schema enum validation - we need to accept any known
    model value (even disabled ones) so that existing graphs don't fail validation.
    The actual fallback/enforcement happens at runtime in llm_call().
    """
    all_slugs = set(_dynamic_models.keys())
    all_slugs.update(_static_metadata.keys())
    return all_slugs


def iter_dynamic_models() -> Iterable[RegistryModel]:
    """Iterate over all dynamic models in the registry."""
    return tuple(_dynamic_models.values())


def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
    """
    Find a fallback model when the requested model is disabled.

    Looks for an enabled model from the same provider. Prefers models with
    similar names or capabilities if possible.

    Args:
        disabled_model_slug: The slug of the disabled model

    Returns:
        An enabled RegistryModel from the same provider, or None if no fallback found
    """
    disabled_model = _dynamic_models.get(disabled_model_slug)
    if not disabled_model:
        return None

    provider = disabled_model.metadata.provider

    # Find all enabled models from the same provider
    candidates = [
        model
        for model in _dynamic_models.values()
        if model.is_enabled and model.metadata.provider == provider
    ]

    if not candidates:
        return None

    # Sort by: prefer models with similar context window, then by name
    candidates.sort(
        key=lambda m: (
            abs(m.metadata.context_window - disabled_model.metadata.context_window),
            m.display_name.lower(),
        )
    )

    return candidates[0]

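To make the fallback ordering above concrete, here is a small worked example. The model names and context windows are invented for illustration; only the (context-window distance, lower-cased name) sort key mirrors the function in the diff:

```python
# Illustration only: how the sort key picks a replacement for a disabled model.
disabled_context_window = 128_000

candidates = [
    {"display_name": "Big-200k", "context_window": 200_000},
    {"display_name": "Mid-128k", "context_window": 128_000},
    {"display_name": "Small-8k", "context_window": 8_000},
]

candidates.sort(
    key=lambda m: (
        abs(m["context_window"] - disabled_context_window),
        m["display_name"].lower(),
    )
)

print(candidates[0]["display_name"])  # -> "Mid-128k" (closest context window wins)
```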
def is_model_enabled(model_slug: str) -> bool:
    """Check if a model is enabled in the registry."""
    model = _dynamic_models.get(model_slug)
    if not model:
        # Model not in registry - assume it's a static/legacy model and allow it
        return True
    return model.is_enabled


def get_model_info(model_slug: str) -> RegistryModel | None:
    """Get model info from the registry."""
    return _dynamic_models.get(model_slug)


def get_default_model_slug() -> str | None:
    """
    Get the default model slug to use for block defaults.

    Returns the recommended model if set (configured via admin UI),
    otherwise returns the first enabled model alphabetically.
    Returns None if no models are available or enabled.
    """
    # Return the recommended model if one is set and enabled
    for model in _dynamic_models.values():
        if model.is_recommended and model.is_enabled:
            return model.slug

    # No recommended model set - find first enabled model alphabetically
    for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
        if model.is_enabled:
            logger.warning(
                "No recommended model set, using '%s' as default",
                model.slug,
            )
            return model.slug

    # No enabled models available
    if _dynamic_models:
        logger.error(
            "No enabled models found in registry (%d models registered but all disabled)",
            len(_dynamic_models),
        )
    else:
        logger.error("No models registered in LLM registry")

    return None
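Taken together, the lookup helpers above give runtime code a small decision procedure. A hedged sketch of how a caller (llm_call-style code, per the docstrings) might combine them; the requested slug is invented:

```python
# Illustration only: combining is_model_enabled, get_fallback_model_for_disabled
# and get_default_model_slug at call time. "some-disabled-model" is hypothetical.
requested = "some-disabled-model"

if is_model_enabled(requested):
    chosen = requested
else:
    fallback = get_fallback_model_for_disabled(requested)
    chosen = fallback.slug if fallback else get_default_model_slug()

print("using model:", chosen)
```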
@@ -0,0 +1,130 @@
"""
Helper utilities for LLM registry integration with block schemas.

This module handles the dynamic injection of discriminator mappings
and model options from the LLM registry into block schemas.
"""

import logging
from typing import Any

from backend.data.llm_registry.registry import (
    get_all_model_slugs_for_validation,
    get_default_model_slug,
    get_llm_discriminator_mapping,
    get_llm_model_schema_options,
)

logger = logging.getLogger(__name__)


def is_llm_model_field(field_name: str, field_info: Any) -> bool:
    """
    Check if a field is an LLM model selection field.

    Returns True if the field has 'options' in json_schema_extra
    (set by llm_model_schema_extra() in blocks/llm.py).
    """
    if not hasattr(field_info, "json_schema_extra"):
        return False

    extra = field_info.json_schema_extra
    if isinstance(extra, dict):
        return "options" in extra

    return False


def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
    """
    Refresh LLM model options from the registry.

    Updates 'options' (for frontend dropdown) to show only enabled models,
    but keeps the 'enum' (for validation) inclusive of ALL known models.

    This is important because:
    - Options: What users see in the dropdown (enabled models only)
    - Enum: What values pass validation (all known models, including disabled)

    Existing graphs may have disabled models selected - they should pass validation
    and the fallback logic in llm_call() will handle using an alternative model.
    """
    fresh_options = get_llm_model_schema_options()
    if not fresh_options:
        return

    # Update options array (UI dropdown) - only enabled models
    if "options" in field_schema:
        field_schema["options"] = fresh_options

    all_known_slugs = get_all_model_slugs_for_validation()
    if all_known_slugs and "enum" in field_schema:
        existing_enum = set(field_schema.get("enum", []))
        combined_enum = existing_enum | all_known_slugs
        field_schema["enum"] = sorted(combined_enum)

    # Set the default value from the registry (recommended model if set, else first enabled)
    # This ensures new blocks have a sensible default pre-selected
    default_slug = get_default_model_slug()
    if default_slug:
        field_schema["default"] = default_slug

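The options-versus-enum split above is easiest to see on a concrete field schema. The dictionary below is a made-up, trimmed-down example of what a block's model field might carry; the real schema in the platform has more keys:

```python
# Hypothetical field schema before the refresh: dropdown and enum both list two models.
field_schema = {
    "options": [{"label": "gpt-4o", "value": "gpt-4o"}],  # what the UI shows
    "enum": ["gpt-4o", "legacy-model"],                   # what validation accepts
    "default": "gpt-4o",
}

# After refresh_llm_model_options(field_schema) one would expect:
#   * "options" -> replaced with only the currently *enabled* registry models
#   * "enum"    -> union of the old enum and *all* known slugs (disabled ones too),
#                  so existing graphs that reference a disabled model still validate
#   * "default" -> the registry's recommended (or first enabled) model slug
```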
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
    """
    Refresh discriminator_mapping for fields that use model-based discrimination.

    The discriminator is already set when AICredentialsField() creates the field.
    We only need to refresh the mapping when models are added/removed.
    """
    if field_schema.get("discriminator") != "model":
        return

    # Always refresh the mapping to get latest models
    fresh_mapping = get_llm_discriminator_mapping()
    if fresh_mapping is not None:
        field_schema["discriminator_mapping"] = fresh_mapping


def update_schema_with_llm_registry(
    schema: dict[str, Any], model_class: type | None = None
) -> None:
    """
    Update a JSON schema with current LLM registry data.

    Refreshes:
    1. Model options for LLM model selection fields (dropdown choices)
    2. Discriminator mappings for credentials fields (model → provider)

    Args:
        schema: The JSON schema to update (mutated in-place)
        model_class: The Pydantic model class (optional, for field introspection)
    """
    properties = schema.get("properties", {})

    for field_name, field_schema in properties.items():
        if not isinstance(field_schema, dict):
            continue

        # Refresh model options for LLM model fields
        if model_class and hasattr(model_class, "model_fields"):
            field_info = model_class.model_fields.get(field_name)
            if field_info and is_llm_model_field(field_name, field_info):
                try:
                    refresh_llm_model_options(field_schema)
                except Exception as exc:
                    logger.warning(
                        "Failed to refresh LLM options for field %s: %s",
                        field_name,
                        exc,
                    )

        # Refresh discriminator mapping for fields that use model discrimination
        try:
            refresh_llm_discriminator_mapping(field_schema)
        except Exception as exc:
            logger.warning(
                "Failed to refresh discriminator mapping for field %s: %s",
                field_name,
                exc,
            )
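A quick usage sketch for the helper above. The toy Pydantic model and the commented expectations are invented for illustration; the only assumption beyond the diff is that block input models expose their fields through standard Pydantic v2 `model_fields` / `model_json_schema()`:

```python
# Illustrative only: a toy block input model run through the schema helper.
from typing import Any

from pydantic import BaseModel, Field


class ToyBlockInput(BaseModel):
    # json_schema_extra carrying "options" is what marks this as an LLM model field
    model: str = Field(default="gpt-4o", json_schema_extra={"options": []})


schema: dict[str, Any] = ToyBlockInput.model_json_schema()
# update_schema_with_llm_registry(schema, ToyBlockInput)
# ...would mutate schema["properties"]["model"] in place: refreshed "options" and
# "default" (and, for fields that declare an "enum", a widened enum), plus a
# discriminator_mapping on any credentials field whose discriminator is "model".
```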
@@ -40,6 +40,7 @@ from pydantic_core import (
|
|||||||
)
|
)
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from backend.data.llm_registry import update_schema_with_llm_registry
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.json import loads as json_loads
|
from backend.util.json import loads as json_loads
|
||||||
from backend.util.request import parse_url
|
from backend.util.request import parse_url
|
||||||
@@ -570,7 +571,9 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
else:
|
else:
|
||||||
schema["credentials_provider"] = allowed_providers
|
schema["credentials_provider"] = allowed_providers
|
||||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||||
# Do not return anything, just mutate schema in place
|
|
||||||
|
# Ensure LLM discriminators are populated (delegates to shared helper)
|
||||||
|
update_schema_with_llm_registry(schema, model_class)
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
json_schema_extra=_add_json_schema_extra, # type: ignore
|
json_schema_extra=_add_json_schema_extra, # type: ignore
|
||||||
@@ -732,16 +735,20 @@ def CredentialsField(
|
|||||||
This is enforced by the `BlockSchema` base class.
|
This is enforced by the `BlockSchema` base class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
field_schema_extra = {
|
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
|
||||||
k: v
|
field_schema_extra: dict[str, Any] = {}
|
||||||
for k, v in {
|
|
||||||
"credentials_scopes": list(required_scopes) or None,
|
# Always include discriminator if provided
|
||||||
"discriminator": discriminator,
|
if discriminator is not None:
|
||||||
"discriminator_mapping": discriminator_mapping,
|
field_schema_extra["discriminator"] = discriminator
|
||||||
"discriminator_values": discriminator_values,
|
# Always include discriminator_mapping when discriminator is set (even if empty initially)
|
||||||
}.items()
|
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
|
||||||
if v is not None
|
|
||||||
}
|
# Include other optional fields (only if not None)
|
||||||
|
if required_scopes:
|
||||||
|
field_schema_extra["credentials_scopes"] = list(required_scopes)
|
||||||
|
if discriminator_values:
|
||||||
|
field_schema_extra["discriminator_values"] = discriminator_values
|
||||||
|
|
||||||
# Merge any json_schema_extra passed in kwargs
|
# Merge any json_schema_extra passed in kwargs
|
||||||
if "json_schema_extra" in kwargs:
|
if "json_schema_extra" in kwargs:
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import logging
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pydantic
|
|
||||||
from prisma.models import UserWorkspace, UserWorkspaceFile
|
from prisma.models import UserWorkspace, UserWorkspaceFile
|
||||||
from prisma.types import UserWorkspaceFileWhereInput
|
from prisma.types import UserWorkspaceFileWhereInput
|
||||||
|
|
||||||
@@ -17,61 +16,7 @@ from backend.util.json import SafeJson
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Workspace(pydantic.BaseModel):
|
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
|
||||||
"""Pydantic model for UserWorkspace, safe for RPC transport."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
user_id: str
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_db(workspace: "UserWorkspace") -> "Workspace":
|
|
||||||
return Workspace(
|
|
||||||
id=workspace.id,
|
|
||||||
user_id=workspace.userId,
|
|
||||||
created_at=workspace.createdAt,
|
|
||||||
updated_at=workspace.updatedAt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceFile(pydantic.BaseModel):
|
|
||||||
"""Pydantic model for UserWorkspaceFile, safe for RPC transport."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
workspace_id: str
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
storage_path: str
|
|
||||||
mime_type: str
|
|
||||||
size_bytes: int
|
|
||||||
checksum: Optional[str] = None
|
|
||||||
is_deleted: bool = False
|
|
||||||
deleted_at: Optional[datetime] = None
|
|
||||||
metadata: dict = pydantic.Field(default_factory=dict)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_db(file: "UserWorkspaceFile") -> "WorkspaceFile":
|
|
||||||
return WorkspaceFile(
|
|
||||||
id=file.id,
|
|
||||||
workspace_id=file.workspaceId,
|
|
||||||
created_at=file.createdAt,
|
|
||||||
updated_at=file.updatedAt,
|
|
||||||
name=file.name,
|
|
||||||
path=file.path,
|
|
||||||
storage_path=file.storagePath,
|
|
||||||
mime_type=file.mimeType,
|
|
||||||
size_bytes=file.sizeBytes,
|
|
||||||
checksum=file.checksum,
|
|
||||||
is_deleted=file.isDeleted,
|
|
||||||
deleted_at=file.deletedAt,
|
|
||||||
metadata=file.metadata if isinstance(file.metadata, dict) else {},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_or_create_workspace(user_id: str) -> Workspace:
|
|
||||||
"""
|
"""
|
||||||
Get user's workspace, creating one if it doesn't exist.
|
Get user's workspace, creating one if it doesn't exist.
|
||||||
|
|
||||||
@@ -82,7 +27,7 @@ async def get_or_create_workspace(user_id: str) -> Workspace:
|
|||||||
user_id: The user's ID
|
user_id: The user's ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Workspace instance
|
UserWorkspace instance
|
||||||
"""
|
"""
|
||||||
workspace = await UserWorkspace.prisma().upsert(
|
workspace = await UserWorkspace.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
@@ -92,10 +37,10 @@ async def get_or_create_workspace(user_id: str) -> Workspace:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return Workspace.from_db(workspace)
|
return workspace
|
||||||
|
|
||||||
|
|
||||||
async def get_workspace(user_id: str) -> Optional[Workspace]:
|
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
|
||||||
"""
|
"""
|
||||||
Get user's workspace if it exists.
|
Get user's workspace if it exists.
|
||||||
|
|
||||||
@@ -103,10 +48,9 @@ async def get_workspace(user_id: str) -> Optional[Workspace]:
|
|||||||
user_id: The user's ID
|
user_id: The user's ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Workspace instance or None
|
UserWorkspace instance or None
|
||||||
"""
|
"""
|
||||||
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
||||||
return Workspace.from_db(workspace) if workspace else None
|
|
||||||
|
|
||||||
|
|
||||||
async def create_workspace_file(
|
async def create_workspace_file(
|
||||||
@@ -119,7 +63,7 @@ async def create_workspace_file(
|
|||||||
size_bytes: int,
|
size_bytes: int,
|
||||||
checksum: Optional[str] = None,
|
checksum: Optional[str] = None,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
) -> WorkspaceFile:
|
) -> UserWorkspaceFile:
|
||||||
"""
|
"""
|
||||||
Create a new workspace file record.
|
Create a new workspace file record.
|
||||||
|
|
||||||
@@ -135,7 +79,7 @@ async def create_workspace_file(
|
|||||||
metadata: Optional additional metadata
|
metadata: Optional additional metadata
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Created WorkspaceFile instance
|
Created UserWorkspaceFile instance
|
||||||
"""
|
"""
|
||||||
# Normalize path to start with /
|
# Normalize path to start with /
|
||||||
if not path.startswith("/"):
|
if not path.startswith("/"):
|
||||||
@@ -159,37 +103,34 @@ async def create_workspace_file(
|
|||||||
f"Created workspace file {file.id} at path {path} "
|
f"Created workspace file {file.id} at path {path} "
|
||||||
f"in workspace {workspace_id}"
|
f"in workspace {workspace_id}"
|
||||||
)
|
)
|
||||||
return WorkspaceFile.from_db(file)
|
return file
|
||||||
|
|
||||||
|
|
||||||
async def get_workspace_file(
|
async def get_workspace_file(
|
||||||
file_id: str,
|
file_id: str,
|
||||||
workspace_id: str,
|
workspace_id: Optional[str] = None,
|
||||||
) -> Optional[WorkspaceFile]:
|
) -> Optional[UserWorkspaceFile]:
|
||||||
"""
|
"""
|
||||||
Get a workspace file by ID.
|
Get a workspace file by ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_id: The file ID
|
file_id: The file ID
|
||||||
workspace_id: Workspace ID for scoping (required)
|
workspace_id: Optional workspace ID for validation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
WorkspaceFile instance or None
|
UserWorkspaceFile instance or None
|
||||||
"""
|
"""
|
||||||
where_clause: UserWorkspaceFileWhereInput = {
|
where_clause: dict = {"id": file_id, "isDeleted": False}
|
||||||
"id": file_id,
|
if workspace_id:
|
||||||
"isDeleted": False,
|
where_clause["workspaceId"] = workspace_id
|
||||||
"workspaceId": workspace_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
file = await UserWorkspaceFile.prisma().find_first(where=where_clause)
|
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
|
||||||
return WorkspaceFile.from_db(file) if file else None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_workspace_file_by_path(
|
async def get_workspace_file_by_path(
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
path: str,
|
path: str,
|
||||||
) -> Optional[WorkspaceFile]:
|
) -> Optional[UserWorkspaceFile]:
|
||||||
"""
|
"""
|
||||||
Get a workspace file by its virtual path.
|
Get a workspace file by its virtual path.
|
||||||
|
|
||||||
@@ -198,20 +139,19 @@ async def get_workspace_file_by_path(
|
|||||||
path: Virtual path
|
path: Virtual path
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
WorkspaceFile instance or None
|
UserWorkspaceFile instance or None
|
||||||
"""
|
"""
|
||||||
# Normalize path
|
# Normalize path
|
||||||
if not path.startswith("/"):
|
if not path.startswith("/"):
|
||||||
path = f"/{path}"
|
path = f"/{path}"
|
||||||
|
|
||||||
file = await UserWorkspaceFile.prisma().find_first(
|
return await UserWorkspaceFile.prisma().find_first(
|
||||||
where={
|
where={
|
||||||
"workspaceId": workspace_id,
|
"workspaceId": workspace_id,
|
||||||
"path": path,
|
"path": path,
|
||||||
"isDeleted": False,
|
"isDeleted": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return WorkspaceFile.from_db(file) if file else None
|
|
||||||
|
|
||||||
|
|
||||||
async def list_workspace_files(
|
async def list_workspace_files(
|
||||||
@@ -220,7 +160,7 @@ async def list_workspace_files(
|
|||||||
include_deleted: bool = False,
|
include_deleted: bool = False,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> list[WorkspaceFile]:
|
) -> list[UserWorkspaceFile]:
|
||||||
"""
|
"""
|
||||||
List files in a workspace.
|
List files in a workspace.
|
||||||
|
|
||||||
@@ -232,7 +172,7 @@ async def list_workspace_files(
|
|||||||
offset: Number of files to skip
|
offset: Number of files to skip
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of WorkspaceFile instances
|
List of UserWorkspaceFile instances
|
||||||
"""
|
"""
|
||||||
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
|
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
|
||||||
|
|
||||||
@@ -245,13 +185,12 @@ async def list_workspace_files(
|
|||||||
path_prefix = f"/{path_prefix}"
|
path_prefix = f"/{path_prefix}"
|
||||||
where_clause["path"] = {"startswith": path_prefix}
|
where_clause["path"] = {"startswith": path_prefix}
|
||||||
|
|
||||||
files = await UserWorkspaceFile.prisma().find_many(
|
return await UserWorkspaceFile.prisma().find_many(
|
||||||
where=where_clause,
|
where=where_clause,
|
||||||
order={"createdAt": "desc"},
|
order={"createdAt": "desc"},
|
||||||
take=limit,
|
take=limit,
|
||||||
skip=offset,
|
skip=offset,
|
||||||
)
|
)
|
||||||
return [WorkspaceFile.from_db(f) for f in files]
|
|
||||||
|
|
||||||
|
|
||||||
async def count_workspace_files(
|
async def count_workspace_files(
|
||||||
@@ -270,7 +209,7 @@ async def count_workspace_files(
|
|||||||
Returns:
|
Returns:
|
||||||
Number of files
|
Number of files
|
||||||
"""
|
"""
|
||||||
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
|
where_clause: dict = {"workspaceId": workspace_id}
|
||||||
if not include_deleted:
|
if not include_deleted:
|
||||||
where_clause["isDeleted"] = False
|
where_clause["isDeleted"] = False
|
||||||
|
|
||||||
@@ -285,8 +224,8 @@ async def count_workspace_files(
|
|||||||
|
|
||||||
async def soft_delete_workspace_file(
|
async def soft_delete_workspace_file(
|
||||||
file_id: str,
|
file_id: str,
|
||||||
workspace_id: str,
|
workspace_id: Optional[str] = None,
|
||||||
) -> Optional[WorkspaceFile]:
|
) -> Optional[UserWorkspaceFile]:
|
||||||
"""
|
"""
|
||||||
Soft-delete a workspace file.
|
Soft-delete a workspace file.
|
||||||
|
|
||||||
@@ -295,10 +234,10 @@ async def soft_delete_workspace_file(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_id: The file ID
|
file_id: The file ID
|
||||||
workspace_id: Workspace ID for scoping (required)
|
workspace_id: Optional workspace ID for validation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated WorkspaceFile instance or None if not found
|
Updated UserWorkspaceFile instance or None if not found
|
||||||
"""
|
"""
|
||||||
# First verify the file exists and belongs to workspace
|
# First verify the file exists and belongs to workspace
|
||||||
file = await get_workspace_file(file_id, workspace_id)
|
file = await get_workspace_file(file_id, workspace_id)
|
||||||
@@ -320,7 +259,7 @@ async def soft_delete_workspace_file(
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Soft-deleted workspace file {file_id}")
|
logger.info(f"Soft-deleted workspace file {file_id}")
|
||||||
return WorkspaceFile.from_db(updated) if updated else None
|
return updated
|
||||||
|
|
||||||
|
|
||||||
async def get_workspace_total_size(workspace_id: str) -> int:
|
async def get_workspace_total_size(workspace_id: str) -> int:
|
||||||
@@ -334,4 +273,4 @@ async def get_workspace_total_size(workspace_id: str) -> int:
|
|||||||
Total size in bytes
|
Total size in bytes
|
||||||
"""
|
"""
|
||||||
files = await list_workspace_files(workspace_id)
|
files = await list_workspace_files(workspace_id)
|
||||||
return sum(file.size_bytes for file in files)
|
return sum(file.sizeBytes for file in files)
|
||||||
|
@@ -1,5 +1,5 @@
 from backend.app import run_processes
-from backend.data.db_manager import DatabaseManager
+from backend.executor import DatabaseManager


 def main():

@@ -1,7 +1,11 @@
+from .database import DatabaseManager, DatabaseManagerAsyncClient, DatabaseManagerClient
 from .manager import ExecutionManager
 from .scheduler import Scheduler

 __all__ = [
+    "DatabaseManager",
+    "DatabaseManagerClient",
+    "DatabaseManagerAsyncClient",
     "ExecutionManager",
     "Scheduler",
 ]

@@ -22,7 +22,7 @@ from backend.util.settings import Settings
 from backend.util.truncate import truncate

 if TYPE_CHECKING:
-    from backend.data.db_manager import DatabaseManagerAsyncClient
+    from backend.executor import DatabaseManagerAsyncClient

 logger = logging.getLogger(__name__)


@@ -4,7 +4,7 @@ import logging
 from typing import TYPE_CHECKING, Any, Literal

 if TYPE_CHECKING:
-    from backend.data.db_manager import DatabaseManagerAsyncClient
+    from backend.executor import DatabaseManagerAsyncClient

 from pydantic import ValidationError

@@ -1,7 +1,6 @@
 """Redis-based distributed locking for cluster coordination."""

 import logging
-import threading
 import time
 from typing import TYPE_CHECKING

@@ -20,7 +19,6 @@ class ClusterLock:
         self.owner_id = owner_id
         self.timeout = timeout
         self._last_refresh = 0.0
-        self._refresh_lock = threading.Lock()

     def try_acquire(self) -> str | None:
         """Try to acquire the lock.
@@ -33,8 +31,7 @@
         try:
             success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
             if success:
-                with self._refresh_lock:
-                    self._last_refresh = time.time()
+                self._last_refresh = time.time()
                 return self.owner_id  # Successfully acquired

             # Failed to acquire, get current owner
@@ -60,27 +57,23 @@
         Rate limited to at most once every timeout/10 seconds (minimum 1 second).
         During rate limiting, still verifies lock existence but skips TTL extension.
         Setting _last_refresh to 0 bypasses rate limiting for testing.
-
-        Thread-safe: uses _refresh_lock to protect _last_refresh access.
         """
         # Calculate refresh interval: max(timeout // 10, 1)
         refresh_interval = max(self.timeout // 10, 1)
         current_time = time.time()

-        # Check if we're within the rate limit period (thread-safe read)
+        # Check if we're within the rate limit period
         # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
-        with self._refresh_lock:
-            last_refresh = self._last_refresh
-            is_rate_limited = (
-                last_refresh > 0 and (current_time - last_refresh) < refresh_interval
-            )
+        is_rate_limited = (
+            self._last_refresh > 0
+            and (current_time - self._last_refresh) < refresh_interval
+        )

         try:
             # Always verify lock existence, even during rate limiting
             current_value = self.redis.get(self.key)
             if not current_value:
-                with self._refresh_lock:
-                    self._last_refresh = 0
+                self._last_refresh = 0
                 return False

             stored_owner = (
@@ -89,8 +82,7 @@
                 else str(current_value)
             )
             if stored_owner != self.owner_id:
-                with self._refresh_lock:
-                    self._last_refresh = 0
+                self._last_refresh = 0
                 return False

             # If rate limited, return True but don't update TTL or timestamp
@@ -99,30 +91,25 @@

             # Perform actual refresh
             if self.redis.expire(self.key, self.timeout):
-                with self._refresh_lock:
-                    self._last_refresh = current_time
+                self._last_refresh = current_time
                 return True

-            with self._refresh_lock:
-                self._last_refresh = 0
+            self._last_refresh = 0
             return False

         except Exception as e:
             logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
-            with self._refresh_lock:
-                self._last_refresh = 0
+            self._last_refresh = 0
             return False

     def release(self):
         """Release the lock."""
-        with self._refresh_lock:
-            if self._last_refresh == 0:
-                return
+        if self._last_refresh == 0:
+            return

         try:
             self.redis.delete(self.key)
         except Exception:
             pass

-        with self._refresh_lock:
-            self._last_refresh = 0.0
+        self._last_refresh = 0.0
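For context on how a lock like this is typically driven, here is a hedged usage sketch. The Redis client setup, key name, constructor argument names, and module path are assumptions for illustration; only the try_acquire / refresh / release call pattern comes from the class in the diff:

```python
# Usage sketch (assumptions: a reachable local Redis; ClusterLock importable
# from a hypothetical module path; keyword argument names inferred from the diff).
import time

import redis

from backend.util.cluster_lock import ClusterLock  # hypothetical import path

r = redis.Redis(host="localhost", port=6379)
lock = ClusterLock(r, key="my-service:leader", owner_id="worker-1", timeout=30)

if lock.try_acquire() == "worker-1":
    try:
        while True:
            if not lock.refresh():  # rate-limited TTL extension + ownership check
                break               # lost the lock: stop doing leader-only work
            time.sleep(1)           # ... leader-only work goes here ...
    finally:
        lock.release()
```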
|||||||
@@ -4,26 +4,14 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas
|
|||||||
|
|
||||||
from backend.api.features.library.db import (
|
from backend.api.features.library.db import (
|
||||||
add_store_agent_to_library,
|
add_store_agent_to_library,
|
||||||
create_graph_in_library,
|
|
||||||
create_library_agent,
|
|
||||||
get_library_agent,
|
|
||||||
get_library_agent_by_graph_id,
|
|
||||||
list_library_agents,
|
list_library_agents,
|
||||||
update_graph_in_library,
|
|
||||||
)
|
|
||||||
from backend.api.features.store.db import (
|
|
||||||
get_agent,
|
|
||||||
get_available_graph,
|
|
||||||
get_store_agent_details,
|
|
||||||
get_store_agents,
|
|
||||||
)
|
)
|
||||||
|
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||||
from backend.api.features.store.embeddings import (
|
from backend.api.features.store.embeddings import (
|
||||||
backfill_missing_embeddings,
|
backfill_missing_embeddings,
|
||||||
cleanup_orphaned_embeddings,
|
cleanup_orphaned_embeddings,
|
||||||
get_embedding_stats,
|
get_embedding_stats,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
|
||||||
from backend.copilot import db as chat_db
|
|
||||||
from backend.data import db
|
from backend.data import db
|
||||||
from backend.data.analytics import (
|
from backend.data.analytics import (
|
||||||
get_accuracy_trends_and_alerts,
|
get_accuracy_trends_and_alerts,
|
||||||
@@ -60,7 +48,6 @@ from backend.data.graph import (
|
|||||||
get_graph_metadata,
|
get_graph_metadata,
|
||||||
get_graph_settings,
|
get_graph_settings,
|
||||||
get_node,
|
get_node,
|
||||||
get_store_listed_graphs,
|
|
||||||
validate_graph_execution_permissions,
|
validate_graph_execution_permissions,
|
||||||
)
|
)
|
||||||
from backend.data.human_review import (
|
from backend.data.human_review import (
|
||||||
@@ -80,10 +67,6 @@ from backend.data.notifications import (
|
|||||||
remove_notifications_from_batch,
|
remove_notifications_from_batch,
|
||||||
)
|
)
|
||||||
from backend.data.onboarding import increment_onboarding_runs
|
from backend.data.onboarding import increment_onboarding_runs
|
||||||
from backend.data.understanding import (
|
|
||||||
get_business_understanding,
|
|
||||||
upsert_business_understanding,
|
|
||||||
)
|
|
||||||
from backend.data.user import (
|
from backend.data.user import (
|
||||||
get_active_user_ids_in_timerange,
|
get_active_user_ids_in_timerange,
|
||||||
get_user_by_id,
|
get_user_by_id,
|
||||||
@@ -93,15 +76,6 @@ from backend.data.user import (
|
|||||||
get_user_notification_preference,
|
get_user_notification_preference,
|
||||||
update_user_integrations,
|
update_user_integrations,
|
||||||
)
|
)
|
||||||
from backend.data.workspace import (
|
|
||||||
count_workspace_files,
|
|
||||||
create_workspace_file,
|
|
||||||
get_or_create_workspace,
|
|
||||||
get_workspace_file,
|
|
||||||
get_workspace_file_by_path,
|
|
||||||
list_workspace_files,
|
|
||||||
soft_delete_workspace_file,
|
|
||||||
)
|
|
||||||
from backend.util.service import (
|
from backend.util.service import (
|
||||||
AppService,
|
AppService,
|
||||||
AppServiceClient,
|
AppServiceClient,
|
||||||
@@ -133,13 +107,6 @@ async def _get_credits(user_id: str) -> int:
|
|||||||
|
|
||||||
|
|
||||||
class DatabaseManager(AppService):
|
class DatabaseManager(AppService):
|
||||||
"""Database connection pooling service.
|
|
||||||
|
|
||||||
This service connects to the Prisma engine and exposes database
|
|
||||||
operations via RPC endpoints. It acts as a centralized connection pool
|
|
||||||
for all services that need database access.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(self, app: "FastAPI"):
|
async def lifespan(self, app: "FastAPI"):
|
||||||
async with super().lifespan(app):
|
async with super().lifespan(app):
|
||||||
@@ -175,15 +142,11 @@ class DatabaseManager(AppService):
|
|||||||
def _(
|
def _(
|
||||||
f: Callable[P, R], name: str | None = None
|
f: Callable[P, R], name: str | None = None
|
||||||
) -> Callable[Concatenate[object, P], R]:
|
) -> Callable[Concatenate[object, P], R]:
|
||||||
"""
|
|
||||||
Exposes a function as an RPC endpoint, and adds a virtual `self` param
|
|
||||||
to the function's type so it can be bound as a method.
|
|
||||||
"""
|
|
||||||
if name is not None:
|
if name is not None:
|
||||||
f.__name__ = name
|
f.__name__ = name
|
||||||
return cast(Callable[Concatenate[object, P], R], expose(f))
|
return cast(Callable[Concatenate[object, P], R], expose(f))
|
||||||
|
|
||||||
# ============ Graph Executions ============ #
|
# Executions
|
||||||
get_child_graph_executions = _(get_child_graph_executions)
|
get_child_graph_executions = _(get_child_graph_executions)
|
||||||
get_graph_executions = _(get_graph_executions)
|
get_graph_executions = _(get_graph_executions)
|
||||||
get_graph_executions_count = _(get_graph_executions_count)
|
get_graph_executions_count = _(get_graph_executions_count)
|
||||||
@@ -207,37 +170,36 @@ class DatabaseManager(AppService):
|
|||||||
get_frequently_executed_graphs = _(get_frequently_executed_graphs)
|
get_frequently_executed_graphs = _(get_frequently_executed_graphs)
|
||||||
get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring)
|
get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring)
|
||||||
|
|
||||||
# ============ Graphs ============ #
|
# Graphs
|
||||||
get_node = _(get_node)
|
get_node = _(get_node)
|
||||||
get_graph = _(get_graph)
|
get_graph = _(get_graph)
|
||||||
get_connected_output_nodes = _(get_connected_output_nodes)
|
get_connected_output_nodes = _(get_connected_output_nodes)
|
||||||
get_graph_metadata = _(get_graph_metadata)
|
get_graph_metadata = _(get_graph_metadata)
|
||||||
get_graph_settings = _(get_graph_settings)
|
get_graph_settings = _(get_graph_settings)
|
||||||
get_store_listed_graphs = _(get_store_listed_graphs)
|
|
||||||
|
|
||||||
# ============ Credits ============ #
|
# Credits
|
||||||
spend_credits = _(_spend_credits, name="spend_credits")
|
spend_credits = _(_spend_credits, name="spend_credits")
|
||||||
get_credits = _(_get_credits, name="get_credits")
|
get_credits = _(_get_credits, name="get_credits")
|
||||||
|
|
||||||
# ============ User + Integrations ============ #
|
# User + User Metadata + User Integrations
|
||||||
get_user_by_id = _(get_user_by_id)
|
|
||||||
get_user_integrations = _(get_user_integrations)
|
get_user_integrations = _(get_user_integrations)
|
||||||
update_user_integrations = _(update_user_integrations)
|
update_user_integrations = _(update_user_integrations)
|
||||||
|
|
||||||
# ============ User Comms ============ #
|
# User Comms - async
|
||||||
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
||||||
|
get_user_by_id = _(get_user_by_id)
|
||||||
get_user_email_by_id = _(get_user_email_by_id)
|
get_user_email_by_id = _(get_user_email_by_id)
|
||||||
get_user_email_verification = _(get_user_email_verification)
|
get_user_email_verification = _(get_user_email_verification)
|
||||||
get_user_notification_preference = _(get_user_notification_preference)
|
get_user_notification_preference = _(get_user_notification_preference)
|
||||||
|
|
||||||
# ============ Human In The Loop ============ #
|
# Human In The Loop
|
||||||
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||||
check_approval = _(check_approval)
|
check_approval = _(check_approval)
|
||||||
get_or_create_human_review = _(get_or_create_human_review)
|
get_or_create_human_review = _(get_or_create_human_review)
|
||||||
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
|
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
|
||||||
update_review_processed_status = _(update_review_processed_status)
|
update_review_processed_status = _(update_review_processed_status)
|
||||||
|
|
||||||
# ============ Notifications ============ #
|
# Notifications - async
|
||||||
clear_all_user_notification_batches = _(clear_all_user_notification_batches)
|
clear_all_user_notification_batches = _(clear_all_user_notification_batches)
|
||||||
create_or_add_to_user_notification_batch = _(
|
create_or_add_to_user_notification_batch = _(
|
||||||
create_or_add_to_user_notification_batch
|
create_or_add_to_user_notification_batch
|
||||||
@@ -250,62 +212,29 @@ class DatabaseManager(AppService):
|
|||||||
get_user_notification_oldest_message_in_batch
|
get_user_notification_oldest_message_in_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
# ============ Library ============ #
|
# Library
|
||||||
list_library_agents = _(list_library_agents)
|
list_library_agents = _(list_library_agents)
|
||||||
add_store_agent_to_library = _(add_store_agent_to_library)
|
add_store_agent_to_library = _(add_store_agent_to_library)
|
||||||
create_graph_in_library = _(create_graph_in_library)
|
|
||||||
create_library_agent = _(create_library_agent)
|
|
||||||
get_library_agent = _(get_library_agent)
|
|
||||||
get_library_agent_by_graph_id = _(get_library_agent_by_graph_id)
|
|
||||||
update_graph_in_library = _(update_graph_in_library)
|
|
||||||
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
||||||
|
|
||||||
# ============ Onboarding ============ #
|
# Onboarding
|
||||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||||
|
|
||||||
# ============ OAuth ============ #
|
# OAuth
|
||||||
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
|
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
|
||||||
|
|
||||||
# ============ Store ============ #
|
# Store
|
||||||
get_store_agents = _(get_store_agents)
|
get_store_agents = _(get_store_agents)
|
||||||
get_store_agent_details = _(get_store_agent_details)
|
get_store_agent_details = _(get_store_agent_details)
|
||||||
get_agent = _(get_agent)
|
|
||||||
get_available_graph = _(get_available_graph)
|
|
||||||
|
|
||||||
# ============ Search ============ #
|
# Store Embeddings
|
||||||
get_embedding_stats = _(get_embedding_stats)
|
get_embedding_stats = _(get_embedding_stats)
|
||||||
backfill_missing_embeddings = _(backfill_missing_embeddings)
|
backfill_missing_embeddings = _(backfill_missing_embeddings)
|
||||||
cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
|
cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
|
||||||
unified_hybrid_search = _(unified_hybrid_search)
|
|
||||||
|
|
||||||
# ============ Summary Data ============ #
|
# Summary data - async
|
||||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||||
|
|
||||||
# ============ Workspace ============ #
|
|
||||||
count_workspace_files = _(count_workspace_files)
|
|
||||||
create_workspace_file = _(create_workspace_file)
|
|
||||||
get_or_create_workspace = _(get_or_create_workspace)
|
|
||||||
get_workspace_file = _(get_workspace_file)
|
|
||||||
get_workspace_file_by_path = _(get_workspace_file_by_path)
|
|
||||||
list_workspace_files = _(list_workspace_files)
|
|
||||||
soft_delete_workspace_file = _(soft_delete_workspace_file)
|
|
||||||
|
|
||||||
# ============ Understanding ============ #
|
|
||||||
get_business_understanding = _(get_business_understanding)
|
|
||||||
upsert_business_understanding = _(upsert_business_understanding)
|
|
||||||
|
|
||||||
# ============ CoPilot Chat Sessions ============ #
|
|
||||||
get_chat_session = _(chat_db.get_chat_session)
|
|
||||||
create_chat_session = _(chat_db.create_chat_session)
|
|
||||||
update_chat_session = _(chat_db.update_chat_session)
|
|
||||||
add_chat_message = _(chat_db.add_chat_message)
|
|
||||||
add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
|
|
||||||
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
|
|
||||||
get_user_session_count = _(chat_db.get_user_session_count)
|
|
||||||
delete_chat_session = _(chat_db.delete_chat_session)
|
|
||||||
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
|
|
||||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManagerClient(AppServiceClient):
|
class DatabaseManagerClient(AppServiceClient):
|
||||||
d = DatabaseManager
|
d = DatabaseManager
|
||||||
@@ -367,50 +296,43 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
def get_service_type(cls):
|
def get_service_type(cls):
|
||||||
return DatabaseManager
|
return DatabaseManager
|
||||||
|
|
||||||
# ============ Graph Executions ============ #
|
|
||||||
create_graph_execution = d.create_graph_execution
|
create_graph_execution = d.create_graph_execution
|
||||||
get_child_graph_executions = d.get_child_graph_executions
|
get_child_graph_executions = d.get_child_graph_executions
|
||||||
get_connected_output_nodes = d.get_connected_output_nodes
|
get_connected_output_nodes = d.get_connected_output_nodes
|
||||||
get_latest_node_execution = d.get_latest_node_execution
|
get_latest_node_execution = d.get_latest_node_execution
|
||||||
get_graph_execution = d.get_graph_execution
|
|
||||||
get_graph_execution_meta = d.get_graph_execution_meta
|
|
||||||
get_graph_executions = d.get_graph_executions
|
|
||||||
get_node_execution = d.get_node_execution
|
|
||||||
get_node_executions = d.get_node_executions
|
|
||||||
update_graph_execution_stats = d.update_graph_execution_stats
|
|
||||||
update_node_execution_status = d.update_node_execution_status
|
|
||||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
|
||||||
upsert_execution_input = d.upsert_execution_input
|
|
||||||
upsert_execution_output = d.upsert_execution_output
|
|
||||||
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
|
|
||||||
get_execution_kv_data = d.get_execution_kv_data
|
|
||||||
set_execution_kv_data = d.set_execution_kv_data
|
|
||||||
|
|
||||||
# ============ Graphs ============ #
|
|
||||||
get_graph = d.get_graph
|
get_graph = d.get_graph
|
||||||
get_graph_metadata = d.get_graph_metadata
|
get_graph_metadata = d.get_graph_metadata
|
||||||
get_graph_settings = d.get_graph_settings
|
get_graph_settings = d.get_graph_settings
|
||||||
|
get_graph_execution = d.get_graph_execution
|
||||||
|
get_graph_execution_meta = d.get_graph_execution_meta
|
||||||
get_node = d.get_node
|
get_node = d.get_node
|
||||||
get_store_listed_graphs = d.get_store_listed_graphs
|
get_node_execution = d.get_node_execution
|
||||||
|
get_node_executions = d.get_node_executions
|
||||||
# ============ User + Integrations ============ #
|
|
||||||
get_user_by_id = d.get_user_by_id
|
get_user_by_id = d.get_user_by_id
|
||||||
get_user_integrations = d.get_user_integrations
|
get_user_integrations = d.get_user_integrations
|
||||||
|
upsert_execution_input = d.upsert_execution_input
|
||||||
|
upsert_execution_output = d.upsert_execution_output
|
||||||
|
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
|
||||||
|
update_graph_execution_stats = d.update_graph_execution_stats
|
||||||
|
update_node_execution_status = d.update_node_execution_status
|
||||||
|
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||||
update_user_integrations = d.update_user_integrations
|
update_user_integrations = d.update_user_integrations
|
||||||
|
get_execution_kv_data = d.get_execution_kv_data
|
||||||
|
set_execution_kv_data = d.set_execution_kv_data
|
||||||
|
|
||||||
# ============ Human In The Loop ============ #
|
# Human In The Loop
|
||||||
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||||
check_approval = d.check_approval
|
check_approval = d.check_approval
|
||||||
get_or_create_human_review = d.get_or_create_human_review
|
get_or_create_human_review = d.get_or_create_human_review
|
||||||
update_review_processed_status = d.update_review_processed_status
|
update_review_processed_status = d.update_review_processed_status
|
||||||
|
|
||||||
# ============ User Comms ============ #
|
# User Comms
|
||||||
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
|
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
|
||||||
get_user_email_by_id = d.get_user_email_by_id
|
get_user_email_by_id = d.get_user_email_by_id
|
||||||
get_user_email_verification = d.get_user_email_verification
|
get_user_email_verification = d.get_user_email_verification
|
||||||
get_user_notification_preference = d.get_user_notification_preference
|
get_user_notification_preference = d.get_user_notification_preference
|
||||||
|
|
||||||
# ============ Notifications ============ #
|
# Notifications
|
||||||
clear_all_user_notification_batches = d.clear_all_user_notification_batches
|
clear_all_user_notification_batches = d.clear_all_user_notification_batches
|
||||||
create_or_add_to_user_notification_batch = (
|
create_or_add_to_user_notification_batch = (
|
||||||
d.create_or_add_to_user_notification_batch
|
d.create_or_add_to_user_notification_batch
|
||||||
@@ -423,55 +345,20 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
d.get_user_notification_oldest_message_in_batch
|
d.get_user_notification_oldest_message_in_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
# ============ Library ============ #
|
# Library
|
||||||
list_library_agents = d.list_library_agents
|
list_library_agents = d.list_library_agents
|
||||||
add_store_agent_to_library = d.add_store_agent_to_library
|
add_store_agent_to_library = d.add_store_agent_to_library
|
||||||
create_graph_in_library = d.create_graph_in_library
|
|
||||||
create_library_agent = d.create_library_agent
|
|
||||||
get_library_agent = d.get_library_agent
|
|
||||||
get_library_agent_by_graph_id = d.get_library_agent_by_graph_id
|
|
||||||
update_graph_in_library = d.update_graph_in_library
|
|
||||||
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
||||||
|
|
||||||
# ============ Onboarding ============ #
|
# Onboarding
|
||||||
increment_onboarding_runs = d.increment_onboarding_runs
|
increment_onboarding_runs = d.increment_onboarding_runs
|
||||||
|
|
||||||
# ============ OAuth ============ #
|
# OAuth
|
||||||
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
|
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
|
||||||
|
|
||||||
# ============ Store ============ #
|
# Store
|
||||||
get_store_agents = d.get_store_agents
|
get_store_agents = d.get_store_agents
|
||||||
get_store_agent_details = d.get_store_agent_details
|
get_store_agent_details = d.get_store_agent_details
|
||||||
get_agent = d.get_agent
|
|
||||||
get_available_graph = d.get_available_graph
|
|
||||||
|
|
||||||
# ============ Search ============ #
|
# Summary data
|
||||||
unified_hybrid_search = d.unified_hybrid_search
|
|
||||||
|
|
||||||
# ============ Summary Data ============ #
|
|
||||||
get_user_execution_summary_data = d.get_user_execution_summary_data
|
get_user_execution_summary_data = d.get_user_execution_summary_data
|
||||||
|
|
||||||
# ============ Workspace ============ #
|
|
||||||
count_workspace_files = d.count_workspace_files
|
|
||||||
create_workspace_file = d.create_workspace_file
|
|
||||||
get_or_create_workspace = d.get_or_create_workspace
|
|
||||||
get_workspace_file = d.get_workspace_file
|
|
||||||
get_workspace_file_by_path = d.get_workspace_file_by_path
|
|
||||||
list_workspace_files = d.list_workspace_files
|
|
||||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
|
||||||
|
|
||||||
# ============ Understanding ============ #
|
|
||||||
get_business_understanding = d.get_business_understanding
|
|
||||||
upsert_business_understanding = d.upsert_business_understanding
|
|
||||||
|
|
||||||
# ============ CoPilot Chat Sessions ============ #
|
|
||||||
get_chat_session = d.get_chat_session
|
|
||||||
create_chat_session = d.create_chat_session
|
|
||||||
update_chat_session = d.update_chat_session
|
|
||||||
add_chat_message = d.add_chat_message
|
|
||||||
add_chat_messages_batch = d.add_chat_messages_batch
|
|
||||||
get_user_chat_sessions = d.get_user_chat_sessions
|
|
||||||
get_user_session_count = d.get_user_session_count
|
|
||||||
delete_chat_session = d.delete_chat_session
|
|
||||||
get_chat_session_message_count = d.get_chat_session_message_count
|
|
||||||
update_tool_message_content = d.update_tool_message_content
|
|
||||||
@@ -0,0 +1,67 @@
"""
Helper functions for LLM registry initialization in executor context.

These functions handle refreshing the LLM registry when the executor starts
and subscribing to real-time updates via Redis pub/sub.
"""

import logging

from backend.blocks._base import BlockSchema
from backend.data import db, llm_registry
from backend.data.block import initialize_blocks
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.llm_registry import subscribe_to_registry_refresh

logger = logging.getLogger(__name__)


async def initialize_registry_for_executor() -> None:
    """
    Initialize blocks and refresh LLM registry in the executor context.

    This must run in the executor's event loop to have access to the database.
    """
    try:
        # Connect to database if not already connected
        if not db.is_connected():
            await db.connect()
            logger.info("[GraphExecutor] Connected to database for registry refresh")

        # Initialize blocks (internally refreshes LLM registry and costs)
        await initialize_blocks()
        logger.info("[GraphExecutor] Blocks initialized")
    except Exception as exc:
        logger.warning(
            "[GraphExecutor] Failed to refresh LLM registry on startup: %s",
            exc,
            exc_info=True,
        )


async def refresh_registry_on_notification() -> None:
    """Refresh LLM registry when notified via Redis pub/sub."""
    try:
        # Ensure DB is connected
        if not db.is_connected():
            await db.connect()

        # Refresh registry and costs
        await llm_registry.refresh_llm_registry()
        await refresh_llm_costs()

        # Clear block schema caches so they regenerate with new model options
        BlockSchema.clear_all_schema_caches()

        logger.info("[GraphExecutor] LLM registry refreshed from notification")
    except Exception as exc:
        logger.error(
            "[GraphExecutor] Failed to refresh LLM registry from notification: %s",
            exc,
            exc_info=True,
        )


async def subscribe_to_registry_updates() -> None:
    """Subscribe to Redis pub/sub for LLM registry refresh notifications."""
    await subscribe_to_registry_refresh(refresh_registry_on_notification)
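The `subscribe_to_registry_refresh` helper itself is not shown in this diff, so its transport details are an assumption. The sketch below shows one common way such a callback subscription can be wired with `redis.asyncio` pub/sub; the channel name, client setup, and loop structure are invented for illustration and are not the platform's actual implementation:

```python
# Hedged sketch only: not the platform's code. Shows a typical pub/sub-driven
# refresh loop that invokes an async callback on every published message.
from typing import Awaitable, Callable

import redis.asyncio as redis

REFRESH_CHANNEL = "llm-registry:refresh"  # invented channel name


async def subscribe_to_registry_refresh(callback: Callable[[], Awaitable[None]]) -> None:
    client = redis.Redis(host="localhost", port=6379)
    pubsub = client.pubsub()
    await pubsub.subscribe(REFRESH_CHANNEL)
    async for message in pubsub.listen():
        if message.get("type") == "message":
            await callback()  # e.g. refresh_registry_on_notification from this diff
```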