mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-17 02:58:01 -05:00
Compare commits
80 Commits
dependabot
...
fix/fronte
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b98b2df40 | ||
|
|
e55f05c7a8 | ||
|
|
4a9b13acb6 | ||
|
|
5ff669e999 | ||
|
|
ec03a13e26 | ||
|
|
b08851f5d7 | ||
|
|
8b1720e61d | ||
|
|
aa5a039c5e | ||
|
|
8b83bb8647 | ||
|
|
e80e4d9cbb | ||
|
|
375d33cca9 | ||
|
|
3b1b2fe30c | ||
|
|
af63b3678e | ||
|
|
631f1bd50a | ||
|
|
5ac941fe2f | ||
|
|
b01ea3fcbd | ||
|
|
3b09a94e3f | ||
|
|
61efee4139 | ||
|
|
e539280e98 | ||
|
|
db8b43bb3d | ||
|
|
923d8baedc | ||
|
|
a55b2e02dc | ||
|
|
6b6648b290 | ||
|
|
c0a9c0410b | ||
|
|
17a77b02c7 | ||
|
|
701fce83ca | ||
|
|
78d89d0faf | ||
|
|
f482eb668b | ||
|
|
4a52b7eca0 | ||
|
|
97847f59f7 | ||
|
|
22ca8955c5 | ||
|
|
43cbe2e011 | ||
|
|
a318832414 | ||
|
|
843c487500 | ||
|
|
47a3a5ef41 | ||
|
|
ec00aa951a | ||
|
|
36fb1ea004 | ||
|
|
a81ac150da | ||
|
|
49ee087496 | ||
|
|
fc25e008b3 | ||
|
|
b0855e8cf2 | ||
|
|
5e2146dd76 | ||
|
|
103a62c9da | ||
|
|
fc8434fb30 | ||
|
|
3ae08cd48e | ||
|
|
4db13837b9 | ||
|
|
df87867625 | ||
|
|
e503126170 | ||
|
|
7ee28197a3 | ||
|
|
818de26d24 | ||
|
|
cb08def96c | ||
|
|
ac2daee5f8 | ||
|
|
266e0d79d4 | ||
|
|
01f443190e | ||
|
|
bdba0033de | ||
|
|
b87c64ce38 | ||
|
|
003affca43 | ||
|
|
290d0d9a9b | ||
|
|
fba61c72ed | ||
|
|
79d45a15d0 | ||
|
|
66f0d97ca2 | ||
|
|
5894a8fcdf | ||
|
|
dff8efa35d | ||
|
|
e26822998f | ||
|
|
88731b1f76 | ||
|
|
c3e407ef09 | ||
|
|
08a60dcb9b | ||
|
|
de78d062a9 | ||
|
|
217e3718d7 | ||
|
|
3dbc03e488 | ||
|
|
b76b5a37c5 | ||
|
|
eed07b173a | ||
|
|
c5e8b0b08f | ||
|
|
cd3e35df9e | ||
|
|
4a7bc006a8 | ||
|
|
4c474417bc | ||
|
|
99e2261254 | ||
|
|
cab498fa8c | ||
|
|
22078671df | ||
|
|
0082a72657 |
37
.branchlet.json
Normal file
37
.branchlet.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"worktreeCopyPatterns": [
|
||||
".env*",
|
||||
".vscode/**",
|
||||
".auth/**",
|
||||
".claude/**",
|
||||
"autogpt_platform/.env*",
|
||||
"autogpt_platform/backend/.env*",
|
||||
"autogpt_platform/frontend/.env*",
|
||||
"autogpt_platform/frontend/.auth/**",
|
||||
"autogpt_platform/db/docker/.env*"
|
||||
],
|
||||
"worktreeCopyIgnores": [
|
||||
"**/node_modules/**",
|
||||
"**/dist/**",
|
||||
"**/.git/**",
|
||||
"**/Thumbs.db",
|
||||
"**/.DS_Store",
|
||||
"**/.next/**",
|
||||
"**/__pycache__/**",
|
||||
"**/.ruff_cache/**",
|
||||
"**/.pytest_cache/**",
|
||||
"**/*.pyc",
|
||||
"**/playwright-report/**",
|
||||
"**/logs/**",
|
||||
"**/site/**"
|
||||
],
|
||||
"worktreePathTemplate": "$BASE_PATH.worktree",
|
||||
"postCreateCmd": [
|
||||
"cd autogpt_platform/autogpt_libs && poetry install",
|
||||
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
||||
"cd autogpt_platform/frontend && pnpm install",
|
||||
"cd docs && pip install -r requirements.txt"
|
||||
],
|
||||
"terminalCommand": "code .",
|
||||
"deleteBranchWithWorktree": false
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
# Ignore everything by default, selectively add things to context
|
||||
*
|
||||
|
||||
# Documentation (for embeddings/search)
|
||||
!docs/
|
||||
|
||||
# Platform - Libs
|
||||
!autogpt_platform/autogpt_libs/autogpt_libs/
|
||||
!autogpt_platform/autogpt_libs/pyproject.toml
|
||||
@@ -16,6 +19,7 @@
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
|
||||
4
.github/workflows/claude-dependabot.yml
vendored
4
.github/workflows/claude-dependabot.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
@@ -74,7 +74,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
4
.github/workflows/claude.yml
vendored
4
.github/workflows/claude.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
@@ -90,7 +90,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
14
.github/workflows/copilot-setup-steps.yml
vendored
14
.github/workflows/copilot-setup-steps.yml
vendored
@@ -34,7 +34,7 @@ jobs:
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
@@ -72,7 +72,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
@@ -108,6 +108,16 @@ jobs:
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
# Remove large unused tools to free disk space for Docker builds
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ jobs:
|
||||
ref: ${{ github.ref_name || 'master' }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
||||
6
.github/workflows/platform-backend-ci.yml
vendored
6
.github/workflows/platform-backend-ci.yml
vendored
@@ -74,7 +74,7 @@ jobs:
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
@@ -134,7 +134,7 @@ jobs:
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
@@ -176,7 +176,7 @@ jobs:
|
||||
}
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
run: poetry run prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
25
.github/workflows/platform-frontend-ci.yml
vendored
25
.github/workflows/platform-frontend-ci.yml
vendored
@@ -11,6 +11,7 @@ on:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
|
||||
@@ -151,6 +152,14 @@ jobs:
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -226,13 +235,25 @@ jobs:
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright artifacts
|
||||
if: failure()
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
|
||||
@@ -11,7 +11,7 @@ jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
# operations-per-run: 5000
|
||||
stale-issue-message: >
|
||||
|
||||
2
.github/workflows/repo-pr-label.yml
vendored
2
.github/workflows/repo-pr-label.yml
vendored
@@ -61,6 +61,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v5
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
2
.github/workflows/repo-workflow-checker.yml
vendored
2
.github/workflows/repo-workflow-checker.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
|
||||
@@ -6,12 +6,14 @@ start-core:
|
||||
|
||||
# Stop core services
|
||||
stop-core:
|
||||
docker compose stop deps
|
||||
docker compose stop
|
||||
|
||||
reset-db:
|
||||
docker compose stop db
|
||||
rm -rf db/docker/volumes/db/data
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
# View logs for core services
|
||||
logs-core:
|
||||
@@ -33,6 +35,7 @@ init-env:
|
||||
migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
run-backend:
|
||||
cd backend && poetry run app
|
||||
@@ -58,4 +61,4 @@ help:
|
||||
@echo " run-backend - Run the backend FastAPI server"
|
||||
@echo " run-frontend - Run the frontend Next.js development server"
|
||||
@echo " test-data - Run the test data creator"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
|
||||
@@ -57,6 +57,9 @@ class APIKeySmith:
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
if not raw_key.startswith(self.PREFIX):
|
||||
raise ValueError("Key without 'agpt_' prefix would fail validation")
|
||||
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
@@ -1,29 +1,25 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from .jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Set up custom OpenAPI schema generation that adds 401 responses
|
||||
Patch a FastAPI instance's `openapi()` method to add 401 responses
|
||||
to all authenticated endpoints.
|
||||
|
||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||
401 responses instead of 403, but FastAPI only automatically adds security
|
||||
responses when auto_error=True.
|
||||
"""
|
||||
# Wrap current method to allow stacking OpenAPI schema modifiers like this
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
openapi_schema = wrapped_openapi()
|
||||
|
||||
# Add 401 response to all endpoints that have security requirements
|
||||
for path, methods in openapi_schema["paths"].items():
|
||||
|
||||
@@ -58,6 +58,13 @@ V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Langfuse Prompt Management
|
||||
# Used for managing the CoPilot system prompt externally
|
||||
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
|
||||
LANGFUSE_PUBLIC_KEY=
|
||||
LANGFUSE_SECRET_KEY=
|
||||
LANGFUSE_HOST=https://cloud.langfuse.com
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
1
autogpt_platform/backend/.gitignore
vendored
1
autogpt_platform/backend/.gitignore
vendored
@@ -18,3 +18,4 @@ load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
migrations/*/rollback*.sql
|
||||
|
||||
@@ -48,7 +48,8 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
RUN poetry run prisma generate
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
@@ -99,6 +100,7 @@ COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migration
|
||||
FROM server_dependencies AS server
|
||||
|
||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
||||
COPY docs /app/docs
|
||||
RUN poetry install --no-ansi --only-root
|
||||
|
||||
ENV PORT=8000
|
||||
|
||||
@@ -108,7 +108,7 @@ import fastapi.testclient
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.server.v2.myroute import router
|
||||
from backend.api.features.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from backend.server.v2.myroute import router
|
||||
from backend.api.features.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import Dict, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||
from backend.data.execution import (
|
||||
ExecutionEventType,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
|
||||
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.api.conn_manager import ConnectionManager
|
||||
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
25
autogpt_platform/backend/backend/api/external/fastapi_app.py
vendored
Normal file
25
autogpt_platform/backend/backend/api/external/fastapi_app.py
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
|
||||
from .v1.routes import v1_router
|
||||
|
||||
external_api = FastAPI(
|
||||
title="AutoGPT External API",
|
||||
description="External API for AutoGPT integrations",
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_api.add_middleware(SecurityHeadersMiddleware)
|
||||
external_api.include_router(v1_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_api,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
107
autogpt_platform/backend/backend/api/external/middleware.py
vendored
Normal file
107
autogpt_platform/backend/backend/api/external/middleware.py
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
from fastapi import HTTPException, Security, status
|
||||
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.auth.api_key import APIKeyInfo, validate_api_key
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidTokenError,
|
||||
OAuthAccessTokenInfo,
|
||||
validate_access_token,
|
||||
)
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
bearer_auth = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
"""Middleware for API key authentication only"""
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
|
||||
)
|
||||
|
||||
api_key_obj = await validate_api_key(api_key)
|
||||
|
||||
if not api_key_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
return api_key_obj
|
||||
|
||||
|
||||
async def require_access_token(
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> OAuthAccessTokenInfo:
|
||||
"""Middleware for OAuth access token authentication only"""
|
||||
if bearer is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Authorization header",
|
||||
)
|
||||
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
return token_info
|
||||
|
||||
|
||||
async def require_auth(
|
||||
api_key: str | None = Security(api_key_header),
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
"""
|
||||
Unified authentication middleware supporting both API keys and OAuth tokens.
|
||||
|
||||
Supports two authentication methods, which are checked in order:
|
||||
1. X-API-Key header (existing API key authentication)
|
||||
2. Authorization: Bearer <token> header (OAuth access token)
|
||||
|
||||
Returns:
|
||||
APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
|
||||
"""
|
||||
# Try API key first
|
||||
if api_key is not None:
|
||||
api_key_info = await validate_api_key(api_key)
|
||||
if api_key_info:
|
||||
return api_key_info
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
# Try OAuth bearer token
|
||||
if bearer is not None:
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
return token_info
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
# No credentials provided
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication. Provide API key or access token.",
|
||||
)
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""
|
||||
Dependency function for checking specific permissions
|
||||
(works with API keys and OAuth tokens)
|
||||
"""
|
||||
|
||||
async def check_permission(
|
||||
auth: APIAuthorizationInfo = Security(require_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
if permission not in auth.scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permission: {permission.value}",
|
||||
)
|
||||
return auth
|
||||
|
||||
return check_permission
|
||||
@@ -16,7 +16,9 @@ from fastapi import APIRouter, Body, HTTPException, Path, Security, status
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.integrations.models import get_all_provider_names
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
@@ -28,8 +30,6 @@ from backend.data.model import (
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.integrations.models import get_all_provider_names
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -255,7 +255,7 @@ def _get_oauth_handler_for_external(
|
||||
|
||||
@integrations_router.get("/providers", response_model=list[ProviderInfo])
|
||||
async def list_providers(
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[ProviderInfo]:
|
||||
@@ -319,7 +319,7 @@ async def list_providers(
|
||||
async def initiate_oauth(
|
||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||
request: OAuthInitiateRequest,
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> OAuthInitiateResponse:
|
||||
@@ -337,7 +337,10 @@ async def initiate_oauth(
|
||||
if not validate_callback_url(request.callback_url):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Callback URL origin is not allowed. Allowed origins: {settings.config.external_oauth_callback_origins}",
|
||||
detail=(
|
||||
f"Callback URL origin is not allowed. "
|
||||
f"Allowed origins: {settings.config.external_oauth_callback_origins}",
|
||||
),
|
||||
)
|
||||
|
||||
# Validate provider
|
||||
@@ -359,13 +362,15 @@ async def initiate_oauth(
|
||||
)
|
||||
|
||||
# Store state token with external flow metadata
|
||||
# Note: initiated_by_api_key_id is only available for API key auth, not OAuth
|
||||
api_key_id = getattr(auth, "id", None) if auth.type == "api_key" else None
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
provider=provider if isinstance(provider_name, str) else provider_name.value,
|
||||
scopes=request.scopes,
|
||||
callback_url=request.callback_url,
|
||||
state_metadata=request.state_metadata,
|
||||
initiated_by_api_key_id=api_key.id,
|
||||
initiated_by_api_key_id=api_key_id,
|
||||
)
|
||||
|
||||
# Build login URL
|
||||
@@ -393,7 +398,7 @@ async def initiate_oauth(
|
||||
async def complete_oauth(
|
||||
provider: Annotated[str, Path(title="The OAuth provider")],
|
||||
request: OAuthCompleteRequest,
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> OAuthCompleteResponse:
|
||||
@@ -406,7 +411,7 @@ async def complete_oauth(
|
||||
"""
|
||||
# Verify state token
|
||||
valid_state = await creds_manager.store.verify_state_token(
|
||||
api_key.user_id, request.state_token, provider
|
||||
auth.user_id, request.state_token, provider
|
||||
)
|
||||
|
||||
if not valid_state:
|
||||
@@ -453,7 +458,7 @@ async def complete_oauth(
|
||||
)
|
||||
|
||||
# Store credentials
|
||||
await creds_manager.create(api_key.user_id, credentials)
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
|
||||
logger.info(f"Successfully completed external OAuth for provider {provider}")
|
||||
|
||||
@@ -470,7 +475,7 @@ async def complete_oauth(
|
||||
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||
async def list_credentials(
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
@@ -479,7 +484,7 @@ async def list_credentials(
|
||||
|
||||
Returns metadata about each credential without exposing sensitive tokens.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(api_key.user_id)
|
||||
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
@@ -499,7 +504,7 @@ async def list_credentials(
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
@@ -507,7 +512,7 @@ async def list_credentials_by_provider(
|
||||
List credentials for a specific provider.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_creds_by_provider(
|
||||
api_key.user_id, provider
|
||||
auth.user_id, provider
|
||||
)
|
||||
return [
|
||||
CredentialSummary(
|
||||
@@ -536,7 +541,7 @@ async def create_credential(
|
||||
CreateUserPasswordCredentialRequest,
|
||||
CreateHostScopedCredentialRequest,
|
||||
] = Body(..., discriminator="type"),
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
|
||||
),
|
||||
) -> CreateCredentialResponse:
|
||||
@@ -591,7 +596,7 @@ async def create_credential(
|
||||
|
||||
# Store credentials
|
||||
try:
|
||||
await creds_manager.create(api_key.user_id, credentials)
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials: {e}")
|
||||
raise HTTPException(
|
||||
@@ -623,7 +628,7 @@ class DeleteCredentialResponse(BaseModel):
|
||||
async def delete_credential(
|
||||
provider: Annotated[str, Path(title="The provider")],
|
||||
cred_id: Annotated[str, Path(title="The credential ID to delete")],
|
||||
api_key: APIKeyInfo = Security(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
|
||||
),
|
||||
) -> DeleteCredentialResponse:
|
||||
@@ -634,7 +639,7 @@ async def delete_credential(
|
||||
use the main API's delete endpoint which handles webhook cleanup and
|
||||
token revocation.
|
||||
"""
|
||||
creds = await creds_manager.store.get_creds_by_id(api_key.user_id, cred_id)
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
@@ -645,6 +650,6 @@ async def delete_credential(
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
|
||||
await creds_manager.delete(api_key.user_id, cred_id)
|
||||
await creds_manager.delete(auth.user_id, cred_id)
|
||||
|
||||
return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)
|
||||
@@ -5,46 +5,60 @@ from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.data.block
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.data import user as user_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .integrations import integrations_router
|
||||
from .tools import tools_router
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
v1_router = APIRouter()
|
||||
|
||||
|
||||
class NodeOutput(TypedDict):
|
||||
key: str
|
||||
value: Any
|
||||
v1_router.include_router(integrations_router)
|
||||
v1_router.include_router(tools_router)
|
||||
|
||||
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: dict[str, Any]
|
||||
class UserInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: Optional[str]
|
||||
email: str
|
||||
timezone: str = Field(
|
||||
description="The user's last known timezone (e.g. 'Europe/Amsterdam'), "
|
||||
"or 'not-set' if not set"
|
||||
)
|
||||
|
||||
|
||||
class ExecutionNodeOutput(TypedDict):
|
||||
node_id: str
|
||||
outputs: list[NodeOutput]
|
||||
@v1_router.get(
|
||||
path="/me",
|
||||
tags=["user", "meta"],
|
||||
)
|
||||
async def get_user_info(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.IDENTITY)
|
||||
),
|
||||
) -> UserInfoResponse:
|
||||
user = await user_db.get_user_by_id(auth.user_id)
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
return UserInfoResponse(
|
||||
id=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
timezone=user.timezone,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -65,7 +79,9 @@ async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
async def execute_graph_block(
|
||||
block_id: str,
|
||||
data: BlockInput,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.EXECUTE_BLOCK)
|
||||
),
|
||||
) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
@@ -85,12 +101,14 @@ async def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.EXECUTE_GRAPH)
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
inputs=node_input,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
@@ -100,6 +118,19 @@ async def execute_graph(
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: dict[str, Any]
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
|
||||
tags=["graphs"],
|
||||
@@ -107,10 +138,12 @@ async def execute_graph(
|
||||
async def get_graph_execution_results(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_GRAPH)
|
||||
),
|
||||
) -> GraphExecutionResult:
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=True,
|
||||
)
|
||||
@@ -122,7 +155,7 @@ async def get_graph_execution_results(
|
||||
if not await graph_db.get_graph(
|
||||
graph_id=graph_exec.graph_id,
|
||||
version=graph_exec.graph_version,
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
):
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
@@ -14,19 +14,19 @@ from fastapi import APIRouter, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.server.v2.chat.tools.models import ToolResponseBase
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.api.features.chat.tools.models import ToolResponseBase
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
tools_router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
|
||||
# Note: We use Security() as a function parameter dependency (api_key: APIKeyInfo = Security(...))
|
||||
# Note: We use Security() as a function parameter dependency (auth: APIAuthorizationInfo = Security(...))
|
||||
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
|
||||
# while still enforcing auth AND giving us access to the api_key for extracting user_id.
|
||||
# while still enforcing auth AND giving us access to auth for extracting user_id.
|
||||
|
||||
|
||||
# Request models
|
||||
@@ -70,7 +70,7 @@ class RunAgentRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def _create_ephemeral_session(user_id: str | None) -> ChatSession:
|
||||
def _create_ephemeral_session(user_id: str) -> ChatSession:
|
||||
"""Create an ephemeral session for stateless API requests."""
|
||||
return ChatSession.new(user_id)
|
||||
|
||||
@@ -80,7 +80,9 @@ def _create_ephemeral_session(user_id: str | None) -> ChatSession:
|
||||
)
|
||||
async def find_agent(
|
||||
request: FindAgentRequest,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.USE_TOOLS)
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Search for agents in the marketplace based on capabilities and user needs.
|
||||
@@ -91,9 +93,9 @@ async def find_agent(
|
||||
Returns:
|
||||
List of matching agents or no results response
|
||||
"""
|
||||
session = _create_ephemeral_session(api_key.user_id)
|
||||
session = _create_ephemeral_session(auth.user_id)
|
||||
result = await find_agent_tool._execute(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
session=session,
|
||||
query=request.query,
|
||||
)
|
||||
@@ -105,7 +107,9 @@ async def find_agent(
|
||||
)
|
||||
async def run_agent(
|
||||
request: RunAgentRequest,
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.USE_TOOLS)
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run or schedule an agent from the marketplace.
|
||||
@@ -129,9 +133,9 @@ async def run_agent(
|
||||
- execution_started: If agent was run or scheduled successfully
|
||||
- error: If something went wrong
|
||||
"""
|
||||
session = _create_ephemeral_session(api_key.user_id)
|
||||
session = _create_ephemeral_session(auth.user_id)
|
||||
result = await run_agent_tool._execute(
|
||||
user_id=api_key.user_id,
|
||||
user_id=auth.user_id,
|
||||
session=session,
|
||||
username_agent_slug=request.username_agent_slug,
|
||||
inputs=request.inputs,
|
||||
@@ -6,9 +6,10 @@ from fastapi import APIRouter, Body, Security
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import admin_get_user_history, get_user_credit_model
|
||||
from backend.server.v2.admin.model import AddUserCreditsResponse, UserHistoryResponse
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
from .model import AddUserCreditsResponse, UserHistoryResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -9,14 +9,15 @@ import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
|
||||
import backend.server.v2.admin.model as admin_model
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .credit_admin_routes import router as credit_admin_router
|
||||
from .model import UserHistoryResponse
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(credit_admin_routes.router)
|
||||
app.include_router(credit_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
@@ -30,7 +31,7 @@ def setup_app_admin_auth(mock_jwt_admin):
|
||||
|
||||
|
||||
def test_add_user_credits_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
admin_user_id: str,
|
||||
target_user_id: str,
|
||||
@@ -42,7 +43,7 @@ def test_add_user_credits_success(
|
||||
return_value=(1500, "transaction-123-uuid")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
||||
"backend.api.features.admin.credit_admin_routes.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
@@ -84,7 +85,7 @@ def test_add_user_credits_success(
|
||||
|
||||
|
||||
def test_add_user_credits_negative_amount(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test credit deduction by admin (negative amount)"""
|
||||
@@ -94,7 +95,7 @@ def test_add_user_credits_negative_amount(
|
||||
return_value=(200, "transaction-456-uuid")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
||||
"backend.api.features.admin.credit_admin_routes.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
@@ -119,12 +120,12 @@ def test_add_user_credits_negative_amount(
|
||||
|
||||
|
||||
def test_get_user_history_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful retrieval of user credit history"""
|
||||
# Mock the admin_get_user_history function
|
||||
mock_history_response = admin_model.UserHistoryResponse(
|
||||
mock_history_response = UserHistoryResponse(
|
||||
history=[
|
||||
UserTransaction(
|
||||
user_id="user-1",
|
||||
@@ -150,7 +151,7 @@ def test_get_user_history_success(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
||||
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||
return_value=mock_history_response,
|
||||
)
|
||||
|
||||
@@ -170,12 +171,12 @@ def test_get_user_history_success(
|
||||
|
||||
|
||||
def test_get_user_history_with_filters(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test user credit history with search and filter parameters"""
|
||||
# Mock the admin_get_user_history function
|
||||
mock_history_response = admin_model.UserHistoryResponse(
|
||||
mock_history_response = UserHistoryResponse(
|
||||
history=[
|
||||
UserTransaction(
|
||||
user_id="user-3",
|
||||
@@ -194,7 +195,7 @@ def test_get_user_history_with_filters(
|
||||
)
|
||||
|
||||
mock_get_history = mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
||||
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||
return_value=mock_history_response,
|
||||
)
|
||||
|
||||
@@ -230,12 +231,12 @@ def test_get_user_history_with_filters(
|
||||
|
||||
|
||||
def test_get_user_history_empty_results(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test user credit history with no results"""
|
||||
# Mock empty history response
|
||||
mock_history_response = admin_model.UserHistoryResponse(
|
||||
mock_history_response = UserHistoryResponse(
|
||||
history=[],
|
||||
pagination=Pagination(
|
||||
total_items=0,
|
||||
@@ -246,7 +247,7 @@ def test_get_user_history_empty_results(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.admin_get_user_history",
|
||||
"backend.api.features.admin.credit_admin_routes.admin_get_user_history",
|
||||
return_value=mock_history_response,
|
||||
)
|
||||
|
||||
@@ -7,9 +7,9 @@ import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.server.v2.store.cache as store_cache
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.model
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.util.json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -24,7 +24,7 @@ router = fastapi.APIRouter(
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
|
||||
response_model=store_model.StoreListingsWithVersionsResponse,
|
||||
)
|
||||
async def get_admin_listings_with_versions(
|
||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||
@@ -48,7 +48,7 @@ async def get_admin_listings_with_versions(
|
||||
StoreListingsWithVersionsResponse with listings and their versions
|
||||
"""
|
||||
try:
|
||||
listings = await backend.server.v2.store.db.get_admin_listings_with_versions(
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
@@ -68,11 +68,11 @@ async def get_admin_listings_with_versions(
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def review_submission(
|
||||
store_listing_version_id: str,
|
||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||
request: store_model.ReviewSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
@@ -87,12 +87,10 @@ async def review_submission(
|
||||
StoreSubmission with updated review information
|
||||
"""
|
||||
try:
|
||||
already_approved = (
|
||||
await backend.server.v2.store.db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await backend.server.v2.store.db.review_store_submission(
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
@@ -136,7 +134,7 @@ async def admin_download_agent_file(
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await backend.server.v2.store.db.get_agent_as_admin(
|
||||
graph_data = await store_db.get_agent_as_admin(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
@@ -6,10 +6,11 @@ from typing import Annotated
|
||||
import fastapi
|
||||
import pydantic
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from autogpt_libs.auth.dependencies import requires_user
|
||||
|
||||
import backend.data.analytics
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
router = fastapi.APIRouter(dependencies=[fastapi.Security(requires_user)])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
340
autogpt_platform/backend/backend/api/features/analytics_test.py
Normal file
340
autogpt_platform/backend/backend/api/features/analytics_test.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""Tests for analytics API endpoints."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from .analytics import router as analytics_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module."""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_metric endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_metric_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging."""
|
||||
mock_result = Mock(id="metric-123-uuid")
|
||||
mock_log_metric = mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "page_load_time",
|
||||
"metric_value": 2.5,
|
||||
"data_string": "/dashboard",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "metric-123-uuid"
|
||||
|
||||
mock_log_metric.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_success",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string,test_id",
|
||||
[
|
||||
(100, "api_calls_count", "external_api", "integer_value"),
|
||||
(0, "error_count", "no_errors", "zero_value"),
|
||||
(-5.2, "temperature_delta", "cooling", "negative_value"),
|
||||
(1.23456789, "precision_test", "float_precision", "float_precision"),
|
||||
(999999999, "large_number", "max_value", "large_number"),
|
||||
(0.0000001, "tiny_number", "min_value", "tiny_number"),
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_various_values(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
metric_value: float,
|
||||
metric_name: str,
|
||||
data_string: str,
|
||||
test_id: str,
|
||||
) -> None:
|
||||
"""Test raw metric logging with various metric values."""
|
||||
mock_result = Mock(id=f"metric-{test_id}-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": metric_name,
|
||||
"metric_value": metric_value,
|
||||
"data_string": data_string,
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"metric_id": response.json(), "test_case": test_id},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
f"analytics_metric_{test_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"metric_name": "test"}, "Field required"),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
|
||||
"Input should be a valid number",
|
||||
),
|
||||
(
|
||||
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_metric_value_and_data_string",
|
||||
"invalid_metric_value_type",
|
||||
"empty_metric_name",
|
||||
"empty_data_string",
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid metric requests."""
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_metric_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database connection failed"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "test_metric",
|
||||
"metric_value": 1.0,
|
||||
"data_string": "test",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Database connection failed" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_analytics endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_analytics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw analytics logging."""
|
||||
mock_result = Mock(id="analytics-789-uuid")
|
||||
mock_log_analytics = mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "user_action",
|
||||
"data": {
|
||||
"action": "button_click",
|
||||
"button_id": "submit_form",
|
||||
"timestamp": "2023-01-01T00:00:00Z",
|
||||
"metadata": {"form_type": "registration", "fields_filled": 5},
|
||||
},
|
||||
"data_index": "button_click_submit_form",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "analytics-789-uuid"
|
||||
|
||||
mock_log_analytics.assert_called_once_with(
|
||||
test_user_id,
|
||||
"user_action",
|
||||
request_data["data"],
|
||||
"button_click_submit_form",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_analytics_success",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_analytics_complex_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test raw analytics logging with complex nested data structures."""
|
||||
mock_result = Mock(id="analytics-complex-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "agent_execution",
|
||||
"data": {
|
||||
"agent_id": "agent_123",
|
||||
"execution_id": "exec_456",
|
||||
"status": "completed",
|
||||
"duration_ms": 3500,
|
||||
"nodes_executed": 15,
|
||||
"blocks_used": [
|
||||
{"block_id": "llm_block", "count": 3},
|
||||
{"block_id": "http_block", "count": 5},
|
||||
{"block_id": "code_block", "count": 2},
|
||||
],
|
||||
"errors": [],
|
||||
"metadata": {
|
||||
"trigger": "manual",
|
||||
"user_tier": "premium",
|
||||
"environment": "production",
|
||||
},
|
||||
},
|
||||
"data_index": "agent_123_exec_456",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"analytics_id": response.json(), "logged_data": request_data["data"]},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
"analytics_log_analytics_complex_data",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"type": "test"}, "Field required"),
|
||||
(
|
||||
{"type": "test", "data": "not_a_dict", "data_index": "test"},
|
||||
"Input should be a valid dictionary",
|
||||
),
|
||||
({"type": "test", "data": {"key": "value"}}, "Field required"),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_data_and_data_index",
|
||||
"invalid_data_type",
|
||||
"missing_data_index",
|
||||
],
|
||||
)
|
||||
def test_log_raw_analytics_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid analytics requests."""
|
||||
response = client.post("/log_raw_analytics", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_analytics_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Analytics DB unreachable"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "test_event",
|
||||
"data": {"key": "value"},
|
||||
"data_index": "test_index",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Analytics DB unreachable" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
@@ -6,17 +6,20 @@ from typing import Sequence
|
||||
|
||||
import prisma
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.data.block
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.db as store_db
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.v2.builder.model import (
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BlockCategoryResponse,
|
||||
BlockResponse,
|
||||
BlockType,
|
||||
@@ -26,8 +29,6 @@ from backend.server.v2.builder.model import (
|
||||
ProviderResponse,
|
||||
SearchEntry,
|
||||
)
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
@@ -2,8 +2,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.model as store_model
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.data.block import BlockInfo
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
@@ -4,11 +4,12 @@ from typing import Annotated, Sequence
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
|
||||
import backend.server.v2.builder.db as builder_db
|
||||
import backend.server.v2.builder.model as builder_model
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import db as builder_db
|
||||
from . import model as builder_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
@@ -12,7 +11,11 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
# OpenAI API Configuration
|
||||
model: str = Field(
|
||||
default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
|
||||
default="anthropic/claude-opus-4.5", description="Default model to use"
|
||||
)
|
||||
title_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
description="Model to use for generating session titles (should be fast/cheap)",
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
@@ -23,12 +26,6 @@ class ChatConfig(BaseSettings):
|
||||
# Session TTL Configuration - 12 hours
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# System Prompt Configuration
|
||||
system_prompt_path: str = Field(
|
||||
default="prompts/chat_system.md",
|
||||
description="Path to system prompt file relative to chat module",
|
||||
)
|
||||
|
||||
# Streaming Configuration
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
@@ -41,6 +38,13 @@ class ChatConfig(BaseSettings):
|
||||
default=3, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||
langfuse_prompt_name: str = Field(
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
@@ -72,43 +76,11 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
def get_system_prompt(self, **template_vars) -> str:
|
||||
"""Load and render the system prompt from file.
|
||||
|
||||
Args:
|
||||
**template_vars: Variables to substitute in the template
|
||||
|
||||
Returns:
|
||||
Rendered system prompt string
|
||||
|
||||
"""
|
||||
# Get the path relative to this module
|
||||
module_dir = Path(__file__).parent
|
||||
prompt_path = module_dir / self.system_prompt_path
|
||||
|
||||
# Check for .j2 extension first (Jinja2 template)
|
||||
j2_path = Path(str(prompt_path) + ".j2")
|
||||
if j2_path.exists():
|
||||
try:
|
||||
from jinja2 import Template
|
||||
|
||||
template = Template(j2_path.read_text())
|
||||
return template.render(**template_vars)
|
||||
except ImportError:
|
||||
# Jinja2 not installed, fall back to reading as plain text
|
||||
return j2_path.read_text()
|
||||
|
||||
# Check for markdown file
|
||||
if prompt_path.exists():
|
||||
content = prompt_path.read_text()
|
||||
|
||||
# Simple variable substitution if Jinja2 is not available
|
||||
for key, value in template_vars.items():
|
||||
placeholder = f"{{{key}}}"
|
||||
content = content.replace(placeholder, str(value))
|
||||
|
||||
return content
|
||||
raise FileNotFoundError(f"System prompt file not found: {prompt_path}")
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
"onboarding": "prompts/onboarding_system.md",
|
||||
}
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
249
autogpt_platform/backend/backend/api/features/chat/db.py
Normal file
249
autogpt_platform/backend/backend/api/features/chat/db.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""Database operations for chat sessions."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort messages by sequence in Python - Prisma Python client doesn't support
|
||||
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
) -> PrismaChatSession:
|
||||
"""Create a new chat session in the database."""
|
||||
data = ChatSessionCreateInput(
|
||||
id=session_id,
|
||||
userId=user_id,
|
||||
credentials=SafeJson({}),
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
session_id: str,
|
||||
credentials: dict[str, Any] | None = None,
|
||||
successful_agent_runs: dict[str, Any] | None = None,
|
||||
successful_agent_schedules: dict[str, Any] | None = None,
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
) -> PrismaChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
|
||||
if credentials is not None:
|
||||
data["credentials"] = SafeJson(credentials)
|
||||
if successful_agent_runs is not None:
|
||||
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
|
||||
if successful_agent_schedules is not None:
|
||||
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
|
||||
if total_prompt_tokens is not None:
|
||||
data["totalPromptTokens"] = total_prompt_tokens
|
||||
if total_completion_tokens is not None:
|
||||
data["totalCompletionTokens"] = total_completion_tokens
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort in Python - Prisma Python doesn't support order_by in include clauses
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
sequence: int,
|
||||
content: str | None = None,
|
||||
name: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
refusal: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput directly
|
||||
# because Prisma's TypedDict validation rejects optional fields set to None.
|
||||
# We only include fields that have values, then cast at the end.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
data["toolCalls"] = SafeJson(tool_calls)
|
||||
if function_call is not None:
|
||||
data["functionCall"] = SafeJson(function_call)
|
||||
|
||||
# Run message create and session timestamp update in parallel for lower latency
|
||||
_, message = await asyncio.gather(
|
||||
PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
),
|
||||
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[PrismaChatMessage]:
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity - if any message creation fails,
|
||||
the entire batch is rolled back.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
created_messages = []
|
||||
|
||||
async with transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return created_messages
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[PrismaChatSession]:
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
return await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
|
||||
async def get_user_session_count(user_id: str) -> int:
|
||||
"""Get the total number of chat sessions for a user."""
|
||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
"""Delete a chat session and all its messages.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: If provided, validates that the session belongs to this user
|
||||
before deletion. This prevents unauthorized deletion of other
|
||||
users' sessions.
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Build typed where clause with optional user_id validation
|
||||
where_clause: ChatSessionWhereInput = {"id": session_id}
|
||||
if user_id is not None:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
result = await PrismaChatSession.prisma().delete_many(where=where_clause)
|
||||
if result == 0:
|
||||
logger.warning(
|
||||
f"No session deleted for {session_id} "
|
||||
f"(user_id validation: {user_id is not None})"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete chat session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
597
autogpt_platform/backend/backend/api/features/chat/model.py
Normal file
597
autogpt_platform/backend/backend/api/features/chat/model.py
Normal file
@@ -0,0 +1,597 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionDeveloperMessageParam,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_assistant_message_param import FunctionCall
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import DatabaseError, RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
# Redis cache key prefix for chat sessions
|
||||
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
|
||||
|
||||
|
||||
def _get_session_cache_key(session_id: str) -> str:
|
||||
"""Get the Redis cache key for a chat session."""
|
||||
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
|
||||
|
||||
|
||||
# Session-level locks to prevent race conditions during concurrent upserts.
|
||||
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
|
||||
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
|
||||
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
|
||||
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str | None = None
|
||||
name: str | None = None
|
||||
tool_call_id: str | None = None
|
||||
refusal: str | None = None
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ChatSession(BaseModel):
|
||||
session_id: str
|
||||
user_id: str
|
||||
title: str | None = None
|
||||
messages: list[ChatMessage]
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
started_at: datetime
|
||||
updated_at: datetime
|
||||
successful_agent_runs: dict[str, int] = {}
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def new(user_id: str) -> "ChatSession":
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
prisma_session: PrismaChatSession,
|
||||
prisma_messages: list[PrismaChatMessage] | None = None,
|
||||
) -> "ChatSession":
|
||||
"""Convert Prisma models to Pydantic ChatSession."""
|
||||
messages = []
|
||||
if prisma_messages:
|
||||
for msg in prisma_messages:
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
name=msg.name,
|
||||
tool_call_id=msg.toolCallId,
|
||||
refusal=msg.refusal,
|
||||
tool_calls=_parse_json_field(msg.toolCalls),
|
||||
function_call=_parse_json_field(msg.functionCall),
|
||||
)
|
||||
)
|
||||
|
||||
# Parse JSON fields from Prisma
|
||||
credentials = _parse_json_field(prisma_session.credentials, default={})
|
||||
successful_agent_runs = _parse_json_field(
|
||||
prisma_session.successfulAgentRuns, default={}
|
||||
)
|
||||
successful_agent_schedules = _parse_json_field(
|
||||
prisma_session.successfulAgentSchedules, default={}
|
||||
)
|
||||
|
||||
# Calculate usage from token counts
|
||||
usage = []
|
||||
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
|
||||
usage.append(
|
||||
Usage(
|
||||
prompt_tokens=prisma_session.totalPromptTokens or 0,
|
||||
completion_tokens=prisma_session.totalCompletionTokens or 0,
|
||||
total_tokens=(prisma_session.totalPromptTokens or 0)
|
||||
+ (prisma_session.totalCompletionTokens or 0),
|
||||
)
|
||||
)
|
||||
|
||||
return ChatSession(
|
||||
session_id=prisma_session.id,
|
||||
user_id=prisma_session.userId,
|
||||
title=prisma_session.title,
|
||||
messages=messages,
|
||||
usage=usage,
|
||||
credentials=credentials,
|
||||
started_at=prisma_session.createdAt,
|
||||
updated_at=prisma_session.updatedAt,
|
||||
successful_agent_runs=successful_agent_runs,
|
||||
successful_agent_schedules=successful_agent_schedules,
|
||||
)
|
||||
|
||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||
messages = []
|
||||
for message in self.messages:
|
||||
if message.role == "developer":
|
||||
m = ChatCompletionDeveloperMessageParam(
|
||||
role="developer",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "system":
|
||||
m = ChatCompletionSystemMessageParam(
|
||||
role="system",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "user":
|
||||
m = ChatCompletionUserMessageParam(
|
||||
role="user",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "assistant":
|
||||
m = ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.function_call:
|
||||
m["function_call"] = FunctionCall(
|
||||
arguments=message.function_call["arguments"],
|
||||
name=message.function_call["name"],
|
||||
)
|
||||
if message.refusal:
|
||||
m["refusal"] = message.refusal
|
||||
if message.tool_calls:
|
||||
t: list[ChatCompletionMessageToolCallParam] = []
|
||||
for tool_call in message.tool_calls:
|
||||
# Tool calls are stored with nested structure: {id, type, function: {name, arguments}}
|
||||
function_data = tool_call.get("function", {})
|
||||
|
||||
# Skip tool calls that are missing required fields
|
||||
if "id" not in tool_call or "name" not in function_data:
|
||||
logger.warning(
|
||||
f"Skipping invalid tool call: missing required fields. "
|
||||
f"Got: {tool_call.keys()}, function keys: {function_data.keys()}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Arguments are stored as a JSON string
|
||||
arguments_str = function_data.get("arguments", "{}")
|
||||
|
||||
t.append(
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call["id"],
|
||||
type="function",
|
||||
function=Function(
|
||||
arguments=arguments_str,
|
||||
name=function_data["name"],
|
||||
),
|
||||
)
|
||||
)
|
||||
m["tool_calls"] = t
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
messages.append(m)
|
||||
elif message.role == "tool":
|
||||
messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
content=message.content or "",
|
||||
tool_call_id=message.tool_call_id or "",
|
||||
)
|
||||
)
|
||||
elif message.role == "function":
|
||||
messages.append(
|
||||
ChatCompletionFunctionMessageParam(
|
||||
role="function",
|
||||
content=message.content,
|
||||
name=message.name or "",
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = _get_session_cache_key(session.session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
# Check if session exists in DB
|
||||
existing = await chat_db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await chat_db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> ChatSession | None:
|
||||
"""Get a chat session by ID.
|
||||
|
||||
Checks Redis cache first, falls back to database if not found.
|
||||
Caches database results back to Redis.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to fetch.
|
||||
user_id: If provided, validates that the session belongs to this user.
|
||||
If None, ownership is not validated (admin/system access).
|
||||
"""
|
||||
# Try cache first
|
||||
try:
|
||||
session = await _get_session_from_cache(session_id)
|
||||
if session:
|
||||
# Verify user ownership if user_id was provided for validation
|
||||
if user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
return session
|
||||
except RedisError:
|
||||
logger.warning(f"Cache error for session {session_id}, trying database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
logger.warning(f"Session {session_id} not found in cache or database")
|
||||
return None
|
||||
|
||||
# Verify user ownership if user_id was provided for validation
|
||||
if user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
"""Update a chat session in both cache and database.
|
||||
|
||||
Uses session-level locking to prevent race conditions when concurrent
|
||||
operations (e.g., background title update and main stream handler)
|
||||
attempt to upsert the same session simultaneously.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If the database write fails. The cache is still updated
|
||||
as a best-effort optimization, but the error is propagated to ensure
|
||||
callers are aware of the persistence failure.
|
||||
RedisError: If the cache write fails (after successful DB write).
|
||||
"""
|
||||
# Acquire session-specific lock to prevent concurrent upserts
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
async with lock:
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
|
||||
db_error: Exception | None = None
|
||||
|
||||
# Save to database (primary storage)
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session {session.session_id} to database: {e}"
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
if db_error is None:
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {e}"
|
||||
) from e
|
||||
# If both failed, log cache error but raise DB error (more critical)
|
||||
logger.warning(
|
||||
f"Cache write also failed for session {session.session_id}: {e}"
|
||||
)
|
||||
|
||||
# Propagate DB error after attempting cache (prevents data loss)
|
||||
if db_error is not None:
|
||||
raise DatabaseError(
|
||||
f"Failed to persist chat session {session.session_id} to database"
|
||||
) from db_error
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str) -> ChatSession:
|
||||
"""Create a new chat session and persist it.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If the database write fails. We fail fast to ensure
|
||||
callers never receive a non-persisted session that only exists
|
||||
in cache (which would be lost when the cache expires).
|
||||
"""
|
||||
session = ChatSession.new(user_id)
|
||||
|
||||
# Create in database first - fail fast if this fails
|
||||
try:
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session {session.session_id} in database: {e}")
|
||||
raise DatabaseError(
|
||||
f"Failed to create chat session {session.session_id} in database"
|
||||
) from e
|
||||
|
||||
# Cache the session (best-effort optimization, DB is source of truth)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[ChatSession], int]:
|
||||
"""Get chat sessions for a user from the database with total count.
|
||||
|
||||
Returns:
|
||||
A tuple of (sessions, total_count) where total_count is the overall
|
||||
number of sessions for the user (not just the current page).
|
||||
"""
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await chat_db.get_user_session_count(user_id)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
# Convert without messages for listing (lighter weight)
|
||||
sessions.append(ChatSession.from_db(prisma_session, None))
|
||||
|
||||
return sessions, total_count
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
"""Delete a chat session from both cache and database.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: If provided, validates that the session belongs to this user
|
||||
before deletion. This prevents unauthorized deletion.
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False otherwise.
|
||||
"""
|
||||
# Delete from database first (with optional user_id validation)
|
||||
# This confirms ownership before invalidating cache
|
||||
deleted = await chat_db.delete_chat_session(session_id, user_id)
|
||||
|
||||
if not deleted:
|
||||
return False
|
||||
|
||||
# Only invalidate cache and clean up lock after DB confirms deletion
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
title: The new title to set.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await chat_db.update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Invalidate cache so next fetch gets updated title
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update title for session {session_id}: {e}")
|
||||
return False
|
||||
119
autogpt_platform/backend/backend/api/features/chat/model_test.py
Normal file
119
autogpt_platform/backend/backend/api/features/chat/model_test.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import pytest
|
||||
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
|
||||
messages = [
|
||||
ChatMessage(content="Hello, how are you?", role="user"),
|
||||
ChatMessage(
|
||||
content="I'm fine, thank you!",
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "t123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "New York"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
content="I'm using the tool to get the weather",
|
||||
role="tool",
|
||||
tool_call_id="t123",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_serialization_deserialization():
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s.messages = messages
|
||||
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
|
||||
serialized = s.model_dump_json()
|
||||
s2 = ChatSession.model_validate_json(serialized)
|
||||
assert s2.model_dump() == s.model_dump()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
|
||||
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
user_id=s.user_id,
|
||||
)
|
||||
|
||||
assert s2 == s
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage_user_id_mismatch(
|
||||
setup_test_user, test_user_id
|
||||
):
|
||||
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(s.session_id, "different_user_id")
|
||||
|
||||
assert s2 is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
||||
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
# Create session with messages including assistant message
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
assert s.session_id is not None, "Session id is not set"
|
||||
# Upsert to save to both cache and DB
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
# Clear the Redis cache to force DB load
|
||||
redis_key = f"chat:session:{s.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
|
||||
# Load from DB (cache was cleared)
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
user_id=s.user_id,
|
||||
)
|
||||
|
||||
assert s2 is not None, "Session not found after loading from DB"
|
||||
assert len(s2.messages) == len(
|
||||
s.messages
|
||||
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
|
||||
|
||||
# Verify all roles are present
|
||||
roles = [m.role for m in s2.messages]
|
||||
assert "user" in roles, f"User message missing. Roles found: {roles}"
|
||||
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
|
||||
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
|
||||
|
||||
# Verify message content
|
||||
for orig, loaded in zip(s.messages, s2.messages):
|
||||
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
|
||||
assert (
|
||||
orig.content == loaded.content
|
||||
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
|
||||
if orig.tool_calls:
|
||||
assert (
|
||||
loaded.tool_calls is not None
|
||||
), f"Tool calls missing for {orig.role} message"
|
||||
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Response models for Vercel AI SDK UI Stream Protocol.
|
||||
|
||||
This module implements the AI SDK UI Stream Protocol (v1) for streaming chat responses.
|
||||
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of streaming responses following AI SDK protocol."""
|
||||
|
||||
# Message lifecycle
|
||||
START = "start"
|
||||
FINISH = "finish"
|
||||
|
||||
# Text streaming
|
||||
TEXT_START = "text-start"
|
||||
TEXT_DELTA = "text-delta"
|
||||
TEXT_END = "text-end"
|
||||
|
||||
# Tool interaction
|
||||
TOOL_INPUT_START = "tool-input-start"
|
||||
TOOL_INPUT_AVAILABLE = "tool-input-available"
|
||||
TOOL_OUTPUT_AVAILABLE = "tool-output-available"
|
||||
|
||||
# Other
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
"""Base response model for all streaming responses."""
|
||||
|
||||
type: ResponseType
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
# ========== Message Lifecycle ==========
|
||||
|
||||
|
||||
class StreamStart(StreamBaseResponse):
|
||||
"""Start of a new message."""
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
|
||||
|
||||
class StreamFinish(StreamBaseResponse):
|
||||
"""End of message/stream."""
|
||||
|
||||
type: ResponseType = ResponseType.FINISH
|
||||
|
||||
|
||||
# ========== Text Streaming ==========
|
||||
|
||||
|
||||
class StreamTextStart(StreamBaseResponse):
|
||||
"""Start of a text block."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_START
|
||||
id: str = Field(..., description="Text block ID")
|
||||
|
||||
|
||||
class StreamTextDelta(StreamBaseResponse):
|
||||
"""Streaming text content delta."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_DELTA
|
||||
id: str = Field(..., description="Text block ID")
|
||||
delta: str = Field(..., description="Text content delta")
|
||||
|
||||
|
||||
class StreamTextEnd(StreamBaseResponse):
|
||||
"""End of a text block."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_END
|
||||
id: str = Field(..., description="Text block ID")
|
||||
|
||||
|
||||
# ========== Tool Interaction ==========
|
||||
|
||||
|
||||
class StreamToolInputStart(StreamBaseResponse):
|
||||
"""Tool call started notification."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_INPUT_START
|
||||
toolCallId: str = Field(..., description="Unique tool call ID")
|
||||
toolName: str = Field(..., description="Name of the tool being called")
|
||||
|
||||
|
||||
class StreamToolInputAvailable(StreamBaseResponse):
|
||||
"""Tool input is ready for execution."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_INPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Unique tool call ID")
|
||||
toolName: str = Field(..., description="Name of the tool being called")
|
||||
input: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Tool input arguments"
|
||||
)
|
||||
|
||||
|
||||
class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
||||
toolName: str | None = Field(
|
||||
default=None, description="Name of the tool that was executed"
|
||||
)
|
||||
success: bool = Field(
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
|
||||
# ========== Other ==========
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics."""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
"""Error response."""
|
||||
|
||||
type: ResponseType = ResponseType.ERROR
|
||||
errorText: str = Field(..., description="Error message text")
|
||||
code: str | None = Field(default=None, description="Error code")
|
||||
details: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional error details"
|
||||
)
|
||||
362
autogpt_platform/backend/backend/api/features/chat/routes.py
Normal file
362
autogpt_platform/backend/backend/api/features/chat/routes.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> ChatSession:
|
||||
"""Validate session exists and belongs to user."""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
return session
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class StreamChatRequest(BaseModel):
|
||||
"""Request model for streaming chat with optional context."""
|
||||
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""Response model containing information on a newly created chat session."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model providing complete details for a chat session, including messages."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
"""Response model for a session summary (without messages)."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
"""Response model for listing chat sessions."""
|
||||
|
||||
sessions: list[SessionSummaryResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def list_sessions(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListSessionsResponse:
|
||||
"""
|
||||
List chat sessions for the authenticated user.
|
||||
|
||||
Returns a paginated list of chat sessions belonging to the current user,
|
||||
ordered by most recently updated.
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of sessions to return (1-100).
|
||||
offset: Number of sessions to skip for pagination.
|
||||
|
||||
Returns:
|
||||
ListSessionsResponse: List of session summaries and total count.
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
total=total_count,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
|
||||
Initiates a new chat session for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user ID parsed from the JWT (required).
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_get(
|
||||
session_id: str,
|
||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
is_user_message: bool = Query(default=True),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (GET - legacy endpoint).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
message: The user's new message to process.
|
||||
user_id: Optional authenticated user ID.
|
||||
is_user_message: Whether the message is a user message.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
)
|
||||
async def session_assign_user(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Assign an authenticated user to a chat session.
|
||||
|
||||
Used (typically post-login) to claim an existing anonymous session as the current authenticated user.
|
||||
|
||||
Args:
|
||||
session_id: The identifier for the (previously anonymous) session.
|
||||
user_id: The authenticated user's ID to associate with the session.
|
||||
|
||||
Returns:
|
||||
dict: Status of the assignment.
|
||||
|
||||
"""
|
||||
await chat_service.assign_user_to_session(session_id, user_id)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
@router.get("/health", status_code=200)
|
||||
async def health_check() -> dict:
|
||||
"""
|
||||
Health check endpoint for the chat service.
|
||||
|
||||
Performs a full cycle test of session creation and retrieval. Should always return healthy
|
||||
if the service and data layer are operational.
|
||||
|
||||
Returns:
|
||||
dict: A status dictionary indicating health, service name, and API version.
|
||||
|
||||
"""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Ensure health check user exists (required for FK constraint)
|
||||
health_check_user_id = "health-check-user"
|
||||
await get_or_create_user(
|
||||
{
|
||||
"sub": health_check_user_id,
|
||||
"email": "health-check@system.local",
|
||||
"user_metadata": {"name": "Health Check User"},
|
||||
}
|
||||
)
|
||||
|
||||
# Create and retrieve session to verify full data layer
|
||||
session = await create_chat_session(health_check_user_id)
|
||||
await get_chat_session(session.session_id, health_check_user_id)
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "chat",
|
||||
"version": "0.1.0",
|
||||
}
|
||||
904
autogpt_platform/backend/backend/api/features/chat/service.py
Normal file
904
autogpt_platform/backend/backend/api/features/chat/service.py
Normal file
@@ -0,0 +1,904 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from langfuse import Langfuse
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
APIStatusError,
|
||||
AsyncOpenAI,
|
||||
RateLimitError,
|
||||
)
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
from backend.data.understanding import (
|
||||
format_understanding_for_prompt,
|
||||
get_business_understanding,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from .response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from .tools import execute_tool, tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
settings = Settings()
|
||||
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
# Langfuse client (lazy initialization)
|
||||
_langfuse_client: Langfuse | None = None
|
||||
|
||||
|
||||
class LangfuseNotConfiguredError(Exception):
|
||||
"""Raised when Langfuse is required but not configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _is_langfuse_configured() -> bool:
|
||||
"""Check if Langfuse credentials are configured."""
|
||||
return bool(
|
||||
settings.secrets.langfuse_public_key and settings.secrets.langfuse_secret_key
|
||||
)
|
||||
|
||||
|
||||
def _get_langfuse_client() -> Langfuse:
|
||||
"""Get or create the Langfuse client for prompt management and tracing."""
|
||||
global _langfuse_client
|
||||
if _langfuse_client is None:
|
||||
if not _is_langfuse_configured():
|
||||
raise LangfuseNotConfiguredError(
|
||||
"Langfuse is not configured. The chat feature requires Langfuse for prompt management. "
|
||||
"Please set the LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables."
|
||||
)
|
||||
_langfuse_client = Langfuse(
|
||||
public_key=settings.secrets.langfuse_public_key,
|
||||
secret_key=settings.secrets.langfuse_secret_key,
|
||||
host=settings.secrets.langfuse_host or "https://cloud.langfuse.com",
|
||||
)
|
||||
return _langfuse_client
|
||||
|
||||
|
||||
def _get_environment() -> str:
|
||||
"""Get the current environment name for Langfuse tagging."""
|
||||
return settings.config.app_env.value
|
||||
|
||||
|
||||
def _get_langfuse_prompt() -> str:
|
||||
"""Fetch the latest production prompt from Langfuse.
|
||||
|
||||
Returns:
|
||||
The compiled prompt text from Langfuse.
|
||||
|
||||
Raises:
|
||||
Exception: If Langfuse is unavailable or prompt fetch fails.
|
||||
"""
|
||||
try:
|
||||
langfuse = _get_langfuse_client()
|
||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||
prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)
|
||||
compiled = prompt.compile()
|
||||
logger.info(
|
||||
f"Fetched prompt '{config.langfuse_prompt_name}' from Langfuse "
|
||||
f"(version: {prompt.version})"
|
||||
)
|
||||
return compiled
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch prompt from Langfuse: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _is_first_session(user_id: str) -> bool:
|
||||
"""Check if this is the user's first chat session.
|
||||
|
||||
Returns True if the user has 1 or fewer sessions (meaning this is their first).
|
||||
"""
|
||||
try:
|
||||
session_count = await chat_db.get_user_session_count(user_id)
|
||||
return session_count <= 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check session count for user {user_id}: {e}")
|
||||
return False # Default to non-onboarding if we can't check
|
||||
|
||||
|
||||
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||
"""Build the full system prompt including business understanding if available.
|
||||
|
||||
Args:
|
||||
user_id: The user ID for fetching business understanding
|
||||
If "default" and this is the user's first session, will use "onboarding" instead.
|
||||
|
||||
Returns:
|
||||
Tuple of (compiled prompt string, Langfuse prompt object for tracing)
|
||||
"""
|
||||
|
||||
langfuse = _get_langfuse_client()
|
||||
|
||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||
prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)
|
||||
|
||||
# If user is authenticated, try to fetch their business understanding
|
||||
understanding = None
|
||||
if user_id:
|
||||
try:
|
||||
understanding = await get_business_understanding(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch business understanding: {e}")
|
||||
understanding = None
|
||||
if understanding:
|
||||
context = format_understanding_for_prompt(understanding)
|
||||
else:
|
||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||
|
||||
compiled = prompt.compile(users_information=context)
|
||||
return compiled, prompt
|
||||
|
||||
|
||||
async def _generate_session_title(message: str) -> str | None:
|
||||
"""Generate a concise title for a chat session based on the first message.
|
||||
|
||||
Args:
|
||||
message: The first user message in the session
|
||||
|
||||
Returns:
|
||||
A short title (3-6 words) or None if generation fails
|
||||
"""
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=config.title_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"Generate a very short title (3-6 words) for a chat conversation "
|
||||
"based on the user's first message. The title should capture the "
|
||||
"main topic or intent. Return ONLY the title, no quotes or punctuation."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": message[:500]}, # Limit input length
|
||||
],
|
||||
max_tokens=20,
|
||||
)
|
||||
title = response.choices[0].message.content
|
||||
if title:
|
||||
# Clean up the title
|
||||
title = title.strip().strip("\"'")
|
||||
# Limit length
|
||||
if len(title) > 50:
|
||||
title = title[:47] + "..."
|
||||
return title
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate session title: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def assign_user_to_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
) -> ChatSession:
|
||||
"""
|
||||
Assign a user to a chat session.
|
||||
"""
|
||||
session = await get_chat_session(session_id, None)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
session.user_id = user_id
|
||||
return await upsert_chat_session(session)
|
||||
|
||||
|
||||
async def stream_chat_completion(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0,
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None, # {url: str, content: str}
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Main entry point for streaming chat completions with database handling.
|
||||
|
||||
This function handles all database operations and delegates streaming
|
||||
to the internal _stream_chat_chunks function.
|
||||
|
||||
Args:
|
||||
session_id: Chat session ID
|
||||
user_message: User's input message
|
||||
user_id: User ID for authentication (None for anonymous)
|
||||
session: Optional pre-loaded session object (for recursive calls to avoid Redis refetch)
|
||||
|
||||
Yields:
|
||||
StreamBaseResponse objects formatted as SSE
|
||||
|
||||
Raises:
|
||||
NotFoundError: If session_id is invalid
|
||||
ValueError: If max_context_messages is exceeded
|
||||
|
||||
"""
|
||||
logger.info(
|
||||
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
||||
)
|
||||
|
||||
# Check if Langfuse is configured - required for chat functionality
|
||||
if not _is_langfuse_configured():
|
||||
logger.error("Chat request failed: Langfuse is not configured")
|
||||
yield StreamError(
|
||||
errorText="Chat service is not available. Langfuse must be configured "
|
||||
"with LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables."
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Langfuse observations will be created after session is loaded (need messages for input)
|
||||
# Initialize to None so finally block can safely check and end them
|
||||
trace = None
|
||||
generation = None
|
||||
|
||||
# Only fetch from Redis if session not provided (initial call)
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
logger.info(
|
||||
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
|
||||
f"message_count={len(session.messages) if session else 0}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Using provided session object: {session.session_id}, "
|
||||
f"message_count={len(session.messages)}"
|
||||
)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
if message:
|
||||
# Build message content with context if provided
|
||||
message_content = message
|
||||
if context and context.get("url") and context.get("content"):
|
||||
context_text = f"Page URL: {context['url']}\n\nPage Content:\n{context['content']}\n\n---\n\nUser Message: {message}"
|
||||
message_content = context_text
|
||||
logger.info(
|
||||
f"Including page context: URL={context['url']}, content_length={len(context['content'])}"
|
||||
)
|
||||
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if is_user_message else "assistant", content=message_content
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Appended message (role={'user' if is_user_message else 'assistant'}), "
|
||||
f"new message_count={len(session.messages)}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Upserting session: {session.session_id} with user id {session.user_id}, "
|
||||
f"message_count={len(session.messages)}"
|
||||
)
|
||||
session = await upsert_chat_session(session)
|
||||
assert session, "Session not found"
|
||||
|
||||
# Generate title for new sessions on first user message (non-blocking)
|
||||
# Check: is_user_message, no title yet, and this is the first user message
|
||||
if is_user_message and message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
# First user message - generate title in background
|
||||
import asyncio
|
||||
|
||||
# Capture only the values we need (not the session object) to avoid
|
||||
# stale data issues when the main flow modifies the session
|
||||
captured_session_id = session_id
|
||||
captured_message = message
|
||||
|
||||
async def _update_title():
|
||||
try:
|
||||
title = await _generate_session_title(captured_message)
|
||||
if title:
|
||||
# Use dedicated title update function that doesn't
|
||||
# touch messages, avoiding race conditions
|
||||
await update_session_title(captured_session_id, title)
|
||||
logger.info(
|
||||
f"Generated title for session {captured_session_id}: {title}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update session title: {e}")
|
||||
|
||||
# Fire and forget - don't block the chat response
|
||||
asyncio.create_task(_update_title())
|
||||
|
||||
# Build system prompt with business understanding
|
||||
system_prompt, langfuse_prompt = await _build_system_prompt(user_id)
|
||||
|
||||
# Build input messages including system prompt for complete Langfuse logging
|
||||
trace_input_messages = [{"role": "system", "content": system_prompt}] + [
|
||||
m.model_dump() for m in session.messages
|
||||
]
|
||||
|
||||
# Create Langfuse trace for this LLM call (each call gets its own trace, grouped by session_id)
|
||||
# Using v3 SDK: start_observation creates a root span, update_trace sets trace-level attributes
|
||||
try:
|
||||
langfuse = _get_langfuse_client()
|
||||
env = _get_environment()
|
||||
trace = langfuse.start_observation(
|
||||
name="chat_completion",
|
||||
input={"messages": trace_input_messages},
|
||||
metadata={
|
||||
"environment": env,
|
||||
"model": config.model,
|
||||
"message_count": len(session.messages),
|
||||
"prompt_name": langfuse_prompt.name if langfuse_prompt else None,
|
||||
"prompt_version": langfuse_prompt.version if langfuse_prompt else None,
|
||||
},
|
||||
)
|
||||
# Set trace-level attributes (session_id, user_id, tags)
|
||||
trace.update_trace(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tags=[env, "copilot"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create Langfuse trace: {e}")
|
||||
|
||||
# Initialize variables that will be used in finally block (must be defined before try)
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
)
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
# Wrap main logic in try/finally to ensure Langfuse observations are always ended
|
||||
try:
|
||||
has_yielded_end = False
|
||||
has_yielded_error = False
|
||||
has_done_tool_call = False
|
||||
has_received_text = False
|
||||
text_streaming_ended = False
|
||||
tool_response_messages: list[ChatMessage] = []
|
||||
should_retry = False
|
||||
|
||||
# Generate unique IDs for AI SDK protocol
|
||||
import uuid as uuid_module
|
||||
|
||||
message_id = str(uuid_module.uuid4())
|
||||
text_block_id = str(uuid_module.uuid4())
|
||||
|
||||
# Yield message start
|
||||
yield StreamStart(messageId=message_id)
|
||||
|
||||
# Create Langfuse generation for each LLM call, linked to the prompt
|
||||
# Using v3 SDK: start_observation with as_type="generation"
|
||||
generation = (
|
||||
trace.start_observation(
|
||||
as_type="generation",
|
||||
name="llm_call",
|
||||
model=config.model,
|
||||
input={"messages": trace_input_messages},
|
||||
prompt=langfuse_prompt,
|
||||
)
|
||||
if trace
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
async for chunk in _stream_chat_chunks(
|
||||
session=session,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
text_block_id=text_block_id,
|
||||
):
|
||||
|
||||
if isinstance(chunk, StreamTextStart):
|
||||
# Emit text-start before first text delta
|
||||
if not has_received_text:
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextDelta):
|
||||
delta = chunk.delta or ""
|
||||
assert assistant_response.content is not None
|
||||
assistant_response.content += delta
|
||||
has_received_text = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextEnd):
|
||||
# Emit text-end after text completes
|
||||
if has_received_text and not text_streaming_ended:
|
||||
text_streaming_ended = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolInputStart):
|
||||
# Emit text-end before first tool call, but only if we've received text
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolInputAvailable):
|
||||
# Accumulate tool calls in OpenAI format
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": chunk.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.toolName,
|
||||
"arguments": orjson.dumps(chunk.input).decode("utf-8"),
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(chunk, StreamToolOutputAvailable):
|
||||
result_content = (
|
||||
chunk.output
|
||||
if isinstance(chunk.output, str)
|
||||
else orjson.dumps(chunk.output).decode("utf-8")
|
||||
)
|
||||
tool_response_messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=chunk.toolCallId,
|
||||
)
|
||||
)
|
||||
has_done_tool_call = True
|
||||
# Track if any tool execution failed
|
||||
if not chunk.success:
|
||||
logger.warning(
|
||||
f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed"
|
||||
)
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
if not has_done_tool_call:
|
||||
# Emit text-end before finish if we received text but haven't closed it
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
has_yielded_end = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamError):
|
||||
has_yielded_error = True
|
||||
elif isinstance(chunk, StreamUsage):
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=chunk.promptTokens,
|
||||
completion_tokens=chunk.completionTokens,
|
||||
total_tokens=chunk.totalTokens,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during stream: {e!s}", exc_info=True)
|
||||
|
||||
# Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.)
|
||||
is_retryable = isinstance(e, (orjson.JSONDecodeError, KeyError, TypeError))
|
||||
|
||||
if is_retryable and retry_count < config.max_retries:
|
||||
logger.info(
|
||||
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
||||
)
|
||||
should_retry = True
|
||||
else:
|
||||
# Non-retryable error or max retries exceeded
|
||||
# Save any partial progress before reporting error
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Add assistant message if it has content or tool calls
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
|
||||
session.messages.extend(messages_to_save)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
if not has_yielded_error:
|
||||
error_message = str(e)
|
||||
if not is_retryable:
|
||||
error_message = f"Non-retryable error: {error_message}"
|
||||
elif retry_count >= config.max_retries:
|
||||
error_message = f"Max retries ({config.max_retries}) exceeded: {error_message}"
|
||||
|
||||
error_response = StreamError(errorText=error_message)
|
||||
yield error_response
|
||||
if not has_yielded_end:
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Handle retry outside of exception handler to avoid nesting
|
||||
if should_retry and retry_count < config.max_retries:
|
||||
logger.info(
|
||||
f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}"
|
||||
)
|
||||
async for chunk in stream_chat_completion(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
retry_count=retry_count + 1,
|
||||
session=session,
|
||||
context=context,
|
||||
):
|
||||
yield chunk
|
||||
return # Exit after retry to avoid double-saving in finally block
|
||||
|
||||
# Normal completion path - save session and handle tool call continuation
|
||||
logger.info(
|
||||
f"Normal completion path: session={session.session_id}, "
|
||||
f"current message_count={len(session.messages)}"
|
||||
)
|
||||
|
||||
# Build the messages list in the correct order
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Add assistant message with tool_calls if any
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
logger.info(
|
||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||
)
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
logger.info(
|
||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||
)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
logger.info(
|
||||
f"Saving {len(tool_response_messages)} tool response messages, "
|
||||
f"total_to_save={len(messages_to_save)}"
|
||||
)
|
||||
|
||||
session.messages.extend(messages_to_save)
|
||||
logger.info(
|
||||
f"Extended session messages, new message_count={len(session.messages)}"
|
||||
)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
# If we did a tool call, stream the chat completion again to get the next response
|
||||
if has_done_tool_call:
|
||||
logger.info(
|
||||
"Tool call executed, streaming chat completion again to get assistant response"
|
||||
)
|
||||
async for chunk in stream_chat_completion(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
session=session, # Pass session object to avoid Redis refetch
|
||||
context=context,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
finally:
|
||||
# Always end Langfuse observations to prevent resource leaks
|
||||
# Guard against None and catch errors to avoid masking original exceptions
|
||||
if generation is not None:
|
||||
try:
|
||||
latest_usage = session.usage[-1] if session.usage else None
|
||||
generation.update(
|
||||
model=config.model,
|
||||
output={
|
||||
"content": assistant_response.content,
|
||||
"tool_calls": accumulated_tool_calls or None,
|
||||
},
|
||||
usage_details=(
|
||||
{
|
||||
"input": latest_usage.prompt_tokens,
|
||||
"output": latest_usage.completion_tokens,
|
||||
"total": latest_usage.total_tokens,
|
||||
}
|
||||
if latest_usage
|
||||
else None
|
||||
),
|
||||
)
|
||||
generation.end()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to end Langfuse generation: {e}")
|
||||
|
||||
if trace is not None:
|
||||
try:
|
||||
if accumulated_tool_calls:
|
||||
trace.update_trace(output={"tool_calls": accumulated_tool_calls})
|
||||
else:
|
||||
trace.update_trace(output={"response": assistant_response.content})
|
||||
trace.end()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to end Langfuse trace: {e}")
|
||||
|
||||
|
||||
# Retry configuration for OpenAI API calls
|
||||
MAX_RETRIES = 3
|
||||
BASE_DELAY_SECONDS = 1.0
|
||||
MAX_DELAY_SECONDS = 30.0
|
||||
|
||||
|
||||
def _is_retryable_error(error: Exception) -> bool:
|
||||
"""Determine if an error is retryable."""
|
||||
if isinstance(error, RateLimitError):
|
||||
return True
|
||||
if isinstance(error, APIConnectionError):
|
||||
return True
|
||||
if isinstance(error, APIStatusError):
|
||||
# APIStatusError has a response with status_code
|
||||
# Retry on 5xx status codes (server errors)
|
||||
if error.response.status_code >= 500:
|
||||
return True
|
||||
if isinstance(error, APIError):
|
||||
# Retry on overloaded errors or 500 errors (may not have status code)
|
||||
error_message = str(error).lower()
|
||||
if "overloaded" in error_message or "internal server error" in error_message:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _stream_chat_chunks(
|
||||
session: ChatSession,
|
||||
tools: list[ChatCompletionToolParam],
|
||||
system_prompt: str | None = None,
|
||||
text_block_id: str | None = None,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""
|
||||
Pure streaming function for OpenAI chat completions with tool calling.
|
||||
|
||||
This function is database-agnostic and focuses only on streaming logic.
|
||||
Implements exponential backoff retry for transient API errors.
|
||||
|
||||
Args:
|
||||
session: Chat session with conversation history
|
||||
tools: Available tools for the model
|
||||
system_prompt: System prompt to prepend to messages
|
||||
|
||||
Yields:
|
||||
SSE formatted JSON response objects
|
||||
|
||||
"""
|
||||
model = config.model
|
||||
|
||||
logger.info("Starting pure chat stream")
|
||||
|
||||
# Build messages with system prompt prepended
|
||||
messages = session.to_openai_messages()
|
||||
if system_prompt:
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam
|
||||
|
||||
system_message = ChatCompletionSystemMessageParam(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
)
|
||||
messages = [system_message] + messages
|
||||
|
||||
# Loop to handle tool calls and continue conversation
|
||||
while True:
|
||||
retry_count = 0
|
||||
last_error: Exception | None = None
|
||||
|
||||
while retry_count <= MAX_RETRIES:
|
||||
try:
|
||||
logger.info(
|
||||
f"Creating OpenAI chat completion stream..."
|
||||
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
|
||||
)
|
||||
|
||||
# Create the stream with proper types
|
||||
stream = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
|
||||
# Variables to accumulate tool calls
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
active_tool_call_idx: int | None = None
|
||||
finish_reason: str | None = None
|
||||
# Track which tool call indices have had their start event emitted
|
||||
emitted_start_for_idx: set[int] = set()
|
||||
|
||||
# Track if we've started the text block
|
||||
text_started = False
|
||||
|
||||
# Process the stream
|
||||
chunk: ChatCompletionChunk
|
||||
async for chunk in stream:
|
||||
if chunk.usage:
|
||||
yield StreamUsage(
|
||||
promptTokens=chunk.usage.prompt_tokens,
|
||||
completionTokens=chunk.usage.completion_tokens,
|
||||
totalTokens=chunk.usage.total_tokens,
|
||||
)
|
||||
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
|
||||
# Capture finish reason
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
logger.info(f"Finish reason: {finish_reason}")
|
||||
|
||||
# Handle content streaming
|
||||
if delta.content:
|
||||
# Emit text-start on first text content
|
||||
if not text_started and text_block_id:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
text_started = True
|
||||
# Stream the text delta
|
||||
text_response = StreamTextDelta(
|
||||
id=text_block_id or "",
|
||||
delta=delta.content,
|
||||
)
|
||||
yield text_response
|
||||
|
||||
# Handle tool calls
|
||||
if delta.tool_calls:
|
||||
for tc_chunk in delta.tool_calls:
|
||||
idx = tc_chunk.index
|
||||
|
||||
# Update active tool call index if needed
|
||||
if (
|
||||
active_tool_call_idx is None
|
||||
or active_tool_call_idx != idx
|
||||
):
|
||||
active_tool_call_idx = idx
|
||||
|
||||
# Ensure we have a tool call object at this index
|
||||
while len(tool_calls) <= idx:
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": "",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Accumulate the tool call data
|
||||
if tc_chunk.id:
|
||||
tool_calls[idx]["id"] = tc_chunk.id
|
||||
if tc_chunk.function:
|
||||
if tc_chunk.function.name:
|
||||
tool_calls[idx]["function"][
|
||||
"name"
|
||||
] = tc_chunk.function.name
|
||||
if tc_chunk.function.arguments:
|
||||
tool_calls[idx]["function"][
|
||||
"arguments"
|
||||
] += tc_chunk.function.arguments
|
||||
|
||||
# Emit StreamToolInputStart only after we have the tool call ID
|
||||
if (
|
||||
idx not in emitted_start_for_idx
|
||||
and tool_calls[idx]["id"]
|
||||
and tool_calls[idx]["function"]["name"]
|
||||
):
|
||||
yield StreamToolInputStart(
|
||||
toolCallId=tool_calls[idx]["id"],
|
||||
toolName=tool_calls[idx]["function"]["name"],
|
||||
)
|
||||
emitted_start_for_idx.add(idx)
|
||||
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
||||
|
||||
# Yield all accumulated tool calls after the stream is complete
|
||||
# This ensures all tool call arguments have been fully received
|
||||
for idx, tool_call in enumerate(tool_calls):
|
||||
try:
|
||||
async for tc in _yield_tool_call(tool_calls, idx, session):
|
||||
yield tc
|
||||
except (orjson.JSONDecodeError, KeyError, TypeError) as e:
|
||||
logger.error(
|
||||
f"Failed to parse tool call {idx}: {e}",
|
||||
exc_info=True,
|
||||
extra={"tool_call": tool_call},
|
||||
)
|
||||
yield StreamError(
|
||||
errorText=f"Invalid tool call arguments for tool {tool_call.get('function', {}).get('name', 'unknown')}: {e}",
|
||||
)
|
||||
# Re-raise to trigger retry logic in the parent function
|
||||
raise
|
||||
|
||||
yield StreamFinish()
|
||||
return
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
|
||||
retry_count += 1
|
||||
# Calculate delay with exponential backoff
|
||||
delay = min(
|
||||
BASE_DELAY_SECONDS * (2 ** (retry_count - 1)),
|
||||
MAX_DELAY_SECONDS,
|
||||
)
|
||||
logger.warning(
|
||||
f"Retryable error in stream: {e!s}. "
|
||||
f"Retrying in {delay:.1f}s (attempt {retry_count}/{MAX_RETRIES})"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue # Retry the stream
|
||||
else:
|
||||
# Non-retryable error or max retries exceeded
|
||||
logger.error(
|
||||
f"Error in stream (not retrying): {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
error_response = StreamError(errorText=str(e))
|
||||
yield error_response
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# If we exit the retry loop without returning, it means we exhausted retries
|
||||
if last_error:
|
||||
logger.error(
|
||||
f"Max retries ({MAX_RETRIES}) exceeded. Last error: {last_error!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield StreamError(errorText=f"Max retries exceeded: {last_error!s}")
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
|
||||
async def _yield_tool_call(
|
||||
tool_calls: list[dict[str, Any]],
|
||||
yield_idx: int,
|
||||
session: ChatSession,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""
|
||||
Yield a tool call and its execution result.
|
||||
|
||||
Raises:
|
||||
orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON
|
||||
KeyError: If expected tool call fields are missing
|
||||
TypeError: If tool call structure is invalid
|
||||
"""
|
||||
tool_name = tool_calls[yield_idx]["function"]["name"]
|
||||
tool_call_id = tool_calls[yield_idx]["id"]
|
||||
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
|
||||
|
||||
# Parse tool call arguments - handle empty arguments gracefully
|
||||
raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
|
||||
if raw_arguments:
|
||||
arguments = orjson.loads(raw_arguments)
|
||||
else:
|
||||
arguments = {}
|
||||
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=arguments,
|
||||
)
|
||||
|
||||
tool_execution_response: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=arguments,
|
||||
tool_call_id=tool_call_id,
|
||||
user_id=session.user_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(f"Yielding Tool execution response: {tool_execution_response}")
|
||||
yield tool_execution_response
|
||||
@@ -3,19 +3,20 @@ from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.chat.service as chat_service
|
||||
from backend.server.v2.chat.response_model import (
|
||||
StreamEnd,
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import (
|
||||
StreamError,
|
||||
StreamTextChunk,
|
||||
StreamToolExecutionResult,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion():
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
@@ -23,7 +24,7 @@ async def test_stream_chat_completion():
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await chat_service.create_chat_session()
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
@@ -34,9 +35,9 @@ async def test_stream_chat_completion():
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextChunk):
|
||||
assistant_message += chunk.content
|
||||
if isinstance(chunk, StreamEnd):
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
@@ -45,7 +46,7 @@ async def test_stream_chat_completion():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls():
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
@@ -53,8 +54,8 @@ async def test_stream_chat_completion_with_tool_calls():
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await chat_service.create_chat_session()
|
||||
session = await chat_service.upsert_chat_session(session)
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
@@ -68,14 +69,14 @@ async def test_stream_chat_completion_with_tool_calls():
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
|
||||
if isinstance(chunk, StreamEnd):
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
if isinstance(chunk, StreamToolExecutionResult):
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await chat_service.get_session(session.session_id)
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
@@ -0,0 +1,55 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .search_docs import SearchDocsTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
# Single source of truth for all tools
|
||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"add_understanding": AddUnderstandingTool(),
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_block": FindBlockTool(),
|
||||
"find_library_agent": FindLibraryAgentTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"agent_output": AgentOutputTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
}
|
||||
|
||||
# Export individual tool instances for backwards compatibility
|
||||
find_agent_tool = TOOL_REGISTRY["find_agent"]
|
||||
run_agent_tool = TOOL_REGISTRY["run_agent"]
|
||||
|
||||
# Generated from registry for OpenAI API
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
|
||||
]
|
||||
|
||||
|
||||
async def execute_tool(
|
||||
tool_name: str,
|
||||
parameters: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
) -> "StreamToolOutputAvailable":
|
||||
"""Execute a tool by name."""
|
||||
tool = TOOL_REGISTRY.get(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
return await tool.execute(user_id, session, tool_call_id, **parameters)
|
||||
@@ -3,8 +3,11 @@ from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
@@ -13,11 +16,9 @@ from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.store import db as store_db
|
||||
|
||||
|
||||
def make_session(user_id: str | None = None):
|
||||
def make_session(user_id: str):
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
@@ -49,13 +50,13 @@ async def setup_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create a test graph with agent input -> agent output
|
||||
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for LLM tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for LLM tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create test OpenAI credentials for the user
|
||||
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for Firecrawl tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for Firecrawl tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
||||
@@ -0,0 +1,119 @@
|
||||
"""Tool for capturing user business understanding incrementally."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AddUnderstandingTool(BaseTool):
|
||||
"""Tool for capturing user's business understanding incrementally."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "add_understanding"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Capture and store information about the user's business context,
|
||||
workflows, pain points, and automation goals. Call this tool whenever the user
|
||||
shares information about their business. Each call incrementally adds to the
|
||||
existing understanding - you don't need to provide all fields at once.
|
||||
|
||||
Use this to build a comprehensive profile that helps recommend better agents
|
||||
and automations for the user's specific needs."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
# Auto-generate from Pydantic model schema
|
||||
schema = BusinessUnderstandingInput.model_json_schema()
|
||||
properties = {}
|
||||
for field_name, field_schema in schema.get("properties", {}).items():
|
||||
prop: dict[str, Any] = {"description": field_schema.get("description", "")}
|
||||
# Handle anyOf for Optional types
|
||||
if "anyOf" in field_schema:
|
||||
for option in field_schema["anyOf"]:
|
||||
if option.get("type") != "null":
|
||||
prop["type"] = option.get("type", "string")
|
||||
if "items" in option:
|
||||
prop["items"] = option["items"]
|
||||
break
|
||||
else:
|
||||
prop["type"] = field_schema.get("type", "string")
|
||||
if "items" in field_schema:
|
||||
prop["items"] = field_schema["items"]
|
||||
properties[field_name] = prop
|
||||
return {"type": "object", "properties": properties, "required": []}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""Requires authentication to store user-specific data."""
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Capture and store business understanding incrementally.
|
||||
|
||||
Each call merges new data with existing understanding:
|
||||
- String fields are overwritten if provided
|
||||
- List fields are appended (with deduplication)
|
||||
"""
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required to save business understanding.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if any data was provided
|
||||
if not any(v is not None for v in kwargs.values()):
|
||||
return ErrorResponse(
|
||||
message="Please provide at least one field to update.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build input model from kwargs (only include fields defined in the model)
|
||||
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
|
||||
input_data = BusinessUnderstandingInput(
|
||||
**{k: v for k, v in kwargs.items() if k in valid_fields}
|
||||
)
|
||||
|
||||
# Track which fields were updated
|
||||
updated_fields = [
|
||||
k for k, v in kwargs.items() if k in valid_fields and v is not None
|
||||
]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
|
||||
# Build current understanding summary (filter out empty values)
|
||||
current_understanding = {
|
||||
k: v
|
||||
for k, v in understanding.model_dump(
|
||||
exclude={"id", "user_id", "created_at", "updated_at"}
|
||||
).items()
|
||||
if v is not None and v != [] and v != ""
|
||||
}
|
||||
|
||||
return UnderstandingUpdatedResponse(
|
||||
message=f"Updated understanding with: {', '.join(updated_fields)}. "
|
||||
"I now have a better picture of your business context.",
|
||||
session_id=session_id,
|
||||
updated_fields=updated_fields,
|
||||
current_understanding=current_understanding,
|
||||
)
|
||||
@@ -0,0 +1,446 @@
|
||||
"""Tool for retrieving agent execution outputs from user's library."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentOutputResponse,
|
||||
ErrorResponse,
|
||||
ExecutionOutputInfo,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .utils import fetch_graph_from_store_slug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentOutputInput(BaseModel):
|
||||
"""Input parameters for the agent_output tool."""
|
||||
|
||||
agent_name: str = ""
|
||||
library_agent_id: str = ""
|
||||
store_slug: str = ""
|
||||
execution_id: str = ""
|
||||
run_time: str = "latest"
|
||||
|
||||
@field_validator(
|
||||
"agent_name",
|
||||
"library_agent_id",
|
||||
"store_slug",
|
||||
"execution_id",
|
||||
"run_time",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> Any:
|
||||
"""Strip whitespace from string fields."""
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
|
||||
def parse_time_expression(
|
||||
time_expr: str | None,
|
||||
) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
Parse time expression into datetime range (start, end).
|
||||
|
||||
Supports: "latest", "yesterday", "today", "last week", "last 7 days",
|
||||
"last month", "last 30 days", ISO date "YYYY-MM-DD", ISO datetime.
|
||||
"""
|
||||
if not time_expr or time_expr.lower() == "latest":
|
||||
return None, None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
expr = time_expr.lower().strip()
|
||||
|
||||
# Relative time expressions lookup
|
||||
relative_times: dict[str, tuple[datetime, datetime]] = {
|
||||
"yesterday": (today_start - timedelta(days=1), today_start),
|
||||
"today": (today_start, now),
|
||||
"last week": (now - timedelta(days=7), now),
|
||||
"last 7 days": (now - timedelta(days=7), now),
|
||||
"last month": (now - timedelta(days=30), now),
|
||||
"last 30 days": (now - timedelta(days=30), now),
|
||||
}
|
||||
if expr in relative_times:
|
||||
return relative_times[expr]
|
||||
|
||||
# Try ISO date format (YYYY-MM-DD)
|
||||
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
|
||||
if date_match:
|
||||
try:
|
||||
year, month, day = map(int, date_match.groups())
|
||||
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
|
||||
return start, start + timedelta(days=1)
|
||||
except ValueError:
|
||||
# Invalid date components (e.g., month=13, day=32)
|
||||
pass
|
||||
|
||||
# Try ISO datetime
|
||||
try:
|
||||
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
|
||||
except ValueError:
|
||||
return None, None
|
||||
|
||||
|
||||
class AgentOutputTool(BaseTool):
|
||||
"""Tool for retrieving execution outputs from user's library agents."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "agent_output"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Retrieve execution outputs from agents in the user's library.
|
||||
|
||||
Identify the agent using one of:
|
||||
- agent_name: Fuzzy search in user's library
|
||||
- library_agent_id: Exact library agent ID
|
||||
- store_slug: Marketplace format 'username/agent-name'
|
||||
|
||||
Select which run to retrieve using:
|
||||
- execution_id: Specific execution ID
|
||||
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||
"""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "Agent name to search for in user's library (fuzzy match)",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Exact library agent ID",
|
||||
},
|
||||
"store_slug": {
|
||||
"type": "string",
|
||||
"description": "Marketplace identifier: 'username/agent-slug'",
|
||||
},
|
||||
"execution_id": {
|
||||
"type": "string",
|
||||
"description": "Specific execution ID to retrieve",
|
||||
},
|
||||
"run_time": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _resolve_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_name: str | None,
|
||||
library_agent_id: str | None,
|
||||
store_slug: str | None,
|
||||
) -> tuple[LibraryAgent | None, str | None]:
|
||||
"""
|
||||
Resolve agent from provided identifiers.
|
||||
Returns (library_agent, error_message).
|
||||
"""
|
||||
# Priority 1: Exact library agent ID
|
||||
if library_agent_id:
|
||||
try:
|
||||
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||
return agent, None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||
return None, f"Library agent '{library_agent_id}' not found"
|
||||
|
||||
# Priority 2: Store slug (username/agent-name)
|
||||
if store_slug and "/" in store_slug:
|
||||
username, agent_slug = store_slug.split("/", 1)
|
||||
graph, _ = await fetch_graph_from_store_slug(username, agent_slug)
|
||||
if not graph:
|
||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||
|
||||
# Find in user's library by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
if not agent:
|
||||
return (
|
||||
None,
|
||||
f"Agent '{store_slug}' is not in your library. "
|
||||
"Add it first to see outputs.",
|
||||
)
|
||||
return agent, None
|
||||
|
||||
# Priority 3: Fuzzy name search in library
|
||||
if agent_name:
|
||||
try:
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=agent_name,
|
||||
page_size=5,
|
||||
)
|
||||
if not response.agents:
|
||||
return (
|
||||
None,
|
||||
f"No agents matching '{agent_name}' found in your library",
|
||||
)
|
||||
|
||||
# Return best match (first result from search)
|
||||
return response.agents[0], None
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching library agents: {e}")
|
||||
return None, f"Error searching for agent: {e}"
|
||||
|
||||
return (
|
||||
None,
|
||||
"Please specify an agent name, library_agent_id, or store_slug",
|
||||
)
|
||||
|
||||
async def _get_execution(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
execution_id: str | None,
|
||||
time_start: datetime | None,
|
||||
time_end: datetime | None,
|
||||
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
|
||||
"""
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
"""
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return None, [], f"Execution '{execution_id}' not found"
|
||||
return execution, [], None
|
||||
|
||||
# Get completed executions with time filters
|
||||
executions = await execution_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
created_time_gte=time_start,
|
||||
created_time_lte=time_end,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
if not executions:
|
||||
return None, [], None # No error, just no executions
|
||||
|
||||
# If only one execution, fetch full details
|
||||
if len(executions) == 1:
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, [], None
|
||||
|
||||
# Multiple executions - return latest with full details, plus list of available
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
return full_execution, executions, None
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
agent: LibraryAgent,
|
||||
execution: GraphExecution | None,
|
||||
available_executions: list[GraphExecutionMeta],
|
||||
session_id: str | None,
|
||||
) -> AgentOutputResponse:
|
||||
"""Build the response based on execution data."""
|
||||
library_agent_link = f"/library/agents/{agent.id}"
|
||||
|
||||
if not execution:
|
||||
return AgentOutputResponse(
|
||||
message=f"No completed executions found for agent '{agent.name}'",
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
total_executions=0,
|
||||
)
|
||||
|
||||
execution_info = ExecutionOutputInfo(
|
||||
execution_id=execution.id,
|
||||
status=execution.status.value,
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
outputs=dict(execution.outputs),
|
||||
inputs_summary=execution.inputs if execution.inputs else None,
|
||||
)
|
||||
|
||||
available_list = None
|
||||
if len(available_executions) > 1:
|
||||
available_list = [
|
||||
{
|
||||
"id": e.id,
|
||||
"status": e.status.value,
|
||||
"started_at": e.started_at.isoformat() if e.started_at else None,
|
||||
}
|
||||
for e in available_executions[:5]
|
||||
]
|
||||
|
||||
message = f"Found execution outputs for agent '{agent.name}'"
|
||||
if len(available_executions) > 1:
|
||||
message += (
|
||||
f". Showing latest of {len(available_executions)} matching executions."
|
||||
)
|
||||
|
||||
return AgentOutputResponse(
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
agent_name=agent.name,
|
||||
agent_id=agent.graph_id,
|
||||
library_agent_id=agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
execution=execution_info,
|
||||
available_executions=available_list,
|
||||
total_executions=len(available_executions) if available_executions else 1,
|
||||
)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the agent_output tool."""
|
||||
session_id = session.session_id
|
||||
|
||||
# Parse and validate input
|
||||
try:
|
||||
input_data = AgentOutputInput(**kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid input: {e}")
|
||||
return ErrorResponse(
|
||||
message="Invalid input parameters",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure user_id is present (should be guaranteed by requires_auth)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if at least one identifier is provided
|
||||
if not any(
|
||||
[
|
||||
input_data.agent_name,
|
||||
input_data.library_agent_id,
|
||||
input_data.store_slug,
|
||||
input_data.execution_id,
|
||||
]
|
||||
):
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Please specify at least one of: agent_name, "
|
||||
"library_agent_id, store_slug, or execution_id"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# If only execution_id provided, we need to find the agent differently
|
||||
if (
|
||||
input_data.execution_id
|
||||
and not input_data.agent_name
|
||||
and not input_data.library_agent_id
|
||||
and not input_data.store_slug
|
||||
):
|
||||
# Fetch execution directly to get graph_id
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return ErrorResponse(
|
||||
message=f"Execution '{input_data.execution_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Find library agent by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, execution.graph_id
|
||||
)
|
||||
if not agent:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Execution found but agent not in your library. "
|
||||
f"Graph ID: {execution.graph_id}"
|
||||
),
|
||||
session_id=session_id,
|
||||
suggestions=["Add the agent to your library to see more details"],
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, [], session_id)
|
||||
|
||||
# Resolve agent from identifiers
|
||||
agent, error = await self._resolve_agent(
|
||||
user_id=user_id,
|
||||
agent_name=input_data.agent_name or None,
|
||||
library_agent_id=input_data.library_agent_id or None,
|
||||
store_slug=input_data.store_slug or None,
|
||||
)
|
||||
|
||||
if error or not agent:
|
||||
return NoResultsResponse(
|
||||
message=error or "Agent not found",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Check the agent name or ID",
|
||||
"Make sure the agent is in your library",
|
||||
],
|
||||
)
|
||||
|
||||
# Parse time expression
|
||||
time_start, time_end = parse_time_expression(input_data.run_time)
|
||||
|
||||
# Fetch execution(s)
|
||||
execution, available_executions, exec_error = await self._get_execution(
|
||||
user_id=user_id,
|
||||
graph_id=agent.graph_id,
|
||||
execution_id=input_data.execution_id or None,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
)
|
||||
|
||||
if exec_error:
|
||||
return ErrorResponse(
|
||||
message=exec_error,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, available_executions, session_id)
|
||||
@@ -0,0 +1,151 @@
|
||||
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .models import (
|
||||
AgentInfo,
|
||||
AgentsFoundResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchSource = Literal["marketplace", "library"]
|
||||
|
||||
|
||||
async def search_agents(
|
||||
query: str,
|
||||
source: SearchSource,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Search for agents in marketplace or user library.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
source: "marketplace" or "library"
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (required for library search)
|
||||
|
||||
Returns:
|
||||
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
|
||||
"""
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query", session_id=session_id
|
||||
)
|
||||
|
||||
if source == "library" and not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents: list[AgentInfo] = []
|
||||
try:
|
||||
if source == "marketplace":
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db.get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
)
|
||||
)
|
||||
else: # library
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
)
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching {source}: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to search {source}. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
suggestions = (
|
||||
[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
]
|
||||
if source == "marketplace"
|
||||
else [
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
]
|
||||
)
|
||||
no_results_msg = (
|
||||
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
||||
if source == "marketplace"
|
||||
else f"No agents matching '{query}' found in your library."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||
)
|
||||
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
|
||||
title += (
|
||||
f"for '{query}'"
|
||||
if source == "marketplace"
|
||||
else f"in your library for '{query}'"
|
||||
)
|
||||
|
||||
message = (
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents."
|
||||
if source == "marketplace"
|
||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
message=message,
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
@@ -53,7 +53,7 @@ class BaseTool:
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
**kwargs,
|
||||
) -> StreamToolExecutionResult:
|
||||
) -> StreamToolOutputAvailable:
|
||||
"""Execute the tool with authentication check.
|
||||
|
||||
Args:
|
||||
@@ -69,10 +69,10 @@ class BaseTool:
|
||||
logger.error(
|
||||
f"Attempted tool call for {self.name} but user not authenticated"
|
||||
)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=NeedLoginResponse(
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=NeedLoginResponse(
|
||||
message=f"Please sign in to use {self.name}",
|
||||
session_id=session.session_id,
|
||||
).model_dump_json(),
|
||||
@@ -81,17 +81,17 @@ class BaseTool:
|
||||
|
||||
try:
|
||||
result = await self._execute(user_id, session, **kwargs)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=result.model_dump_json(),
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=result.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=ErrorResponse(
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=ErrorResponse(
|
||||
message=f"An error occurred while executing {self.name}",
|
||||
error=str(e),
|
||||
session_id=session.session_id,
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Tool for discovering agents from marketplace."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
from .models import ToolResponseBase
|
||||
|
||||
|
||||
class FindAgentTool(BaseTool):
|
||||
"""Tool for discovering agents from the marketplace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Discover agents from the marketplace based on capabilities and user needs."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query describing what the user wants to accomplish. Use single keywords for best results.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
return await search_agents(
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="marketplace",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -0,0 +1,192 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
||||
from backend.api.features.chat.tools.models import (
|
||||
BlockInfoSummary,
|
||||
BlockInputFieldInfo,
|
||||
BlockListResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.data.block import get_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindBlockTool(BaseTool):
|
||||
"""Tool for searching available blocks."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for available blocks by name or description. "
|
||||
"Blocks are reusable components that perform specific tasks like "
|
||||
"sending emails, making API calls, processing text, etc. "
|
||||
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
|
||||
"The response includes each block's id, required_inputs, and input_schema."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find blocks by name or description. "
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for blocks matching the query.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
BlockListResponse: List of matching blocks
|
||||
NoResultsResponse: No blocks found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Search for blocks using hybrid search
|
||||
results, total = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||
"Check spelling of technical terms",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Enrich results with full block information
|
||||
blocks: list[BlockInfoSummary] = []
|
||||
for result in results:
|
||||
block_id = result["content_id"]
|
||||
block = get_block(block_id)
|
||||
|
||||
if block:
|
||||
# Get input/output schemas
|
||||
input_schema = {}
|
||||
output_schema = {}
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get categories from block instance
|
||||
categories = []
|
||||
if hasattr(block, "categories") and block.categories:
|
||||
categories = [cat.value for cat in block.categories]
|
||||
|
||||
# Extract required inputs for easier use
|
||||
required_inputs: list[BlockInputFieldInfo] = []
|
||||
if input_schema:
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = set(input_schema.get("required", []))
|
||||
# Get credential field names to exclude from required inputs
|
||||
credentials_fields = set(
|
||||
block.input_schema.get_credentials_fields().keys()
|
||||
)
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields - they're handled separately
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
required_inputs.append(
|
||||
BlockInputFieldInfo(
|
||||
name=field_name,
|
||||
type=field_schema.get("type", "string"),
|
||||
description=field_schema.get("description", ""),
|
||||
required=field_name in required_fields,
|
||||
default=field_schema.get("default"),
|
||||
)
|
||||
)
|
||||
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=categories,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
required_inputs=required_inputs,
|
||||
)
|
||||
)
|
||||
|
||||
if not blocks:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block(s) matching '{query}'. "
|
||||
"To execute a block, use run_block with the block's 'id' field "
|
||||
"and provide 'input_data' matching the block's input_schema."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching blocks: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search blocks",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
from .models import ToolResponseBase
|
||||
|
||||
|
||||
class FindLibraryAgentTool(BaseTool):
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_library_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for agents in the user's library. Use this to find agents "
|
||||
"the user has already added to their library, including agents they "
|
||||
"created or added from the marketplace."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query to find agents by name or description.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
return await search_agents(
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="library",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -0,0 +1,148 @@
|
||||
"""GetDocPageTool - Fetch full content of a documentation page."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
DocPageResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Base URL for documentation (can be configured)
|
||||
DOCS_BASE_URL = "https://docs.agpt.co"
|
||||
|
||||
|
||||
class GetDocPageTool(BaseTool):
|
||||
"""Tool for fetching full content of a documentation page."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_doc_page"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Get the full content of a documentation page by its path. "
|
||||
"Use this after search_docs to read the complete content of a relevant page."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The path to the documentation file, as returned by search_docs. "
|
||||
"Example: 'platform/block-sdk-guide.md'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False # Documentation is public
|
||||
|
||||
def _get_docs_root(self) -> Path:
|
||||
"""Get the documentation root directory."""
|
||||
this_file = Path(__file__)
|
||||
project_root = this_file.parent.parent.parent.parent.parent.parent.parent.parent
|
||||
return project_root / "docs"
|
||||
|
||||
def _extract_title(self, content: str, fallback: str) -> str:
|
||||
"""Extract title from markdown content."""
|
||||
lines = content.split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("# "):
|
||||
return line[2:].strip()
|
||||
return fallback
|
||||
|
||||
def _make_doc_url(self, path: str) -> str:
|
||||
"""Create a URL for a documentation page."""
|
||||
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
||||
return f"{DOCS_BASE_URL}/{url_path}"
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Fetch full content of a documentation page.
|
||||
|
||||
Args:
|
||||
user_id: User ID (not required for docs)
|
||||
session: Chat session
|
||||
path: Path to the documentation file
|
||||
|
||||
Returns:
|
||||
DocPageResponse: Full document content
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
path = kwargs.get("path", "").strip()
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide a documentation path.",
|
||||
error="Missing path parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Sanitize path to prevent directory traversal
|
||||
if ".." in path or path.startswith("/"):
|
||||
return ErrorResponse(
|
||||
message="Invalid documentation path.",
|
||||
error="invalid_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
docs_root = self._get_docs_root()
|
||||
full_path = docs_root / path
|
||||
|
||||
if not full_path.exists():
|
||||
return ErrorResponse(
|
||||
message=f"Documentation page not found: {path}",
|
||||
error="not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure the path is within docs root
|
||||
try:
|
||||
full_path.resolve().relative_to(docs_root.resolve())
|
||||
except ValueError:
|
||||
return ErrorResponse(
|
||||
message="Invalid documentation path.",
|
||||
error="invalid_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
content = full_path.read_text(encoding="utf-8")
|
||||
title = self._extract_title(content, path)
|
||||
|
||||
return DocPageResponse(
|
||||
message=f"Retrieved documentation page: {title}",
|
||||
title=title,
|
||||
path=path,
|
||||
content=content,
|
||||
doc_url=self._make_doc_url(path),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read documentation page {path}: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read documentation page: {str(e)}",
|
||||
error="read_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Pydantic models for tool responses."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
@@ -11,14 +12,19 @@ from backend.data.model import CredentialsMetaInput
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of tool responses."""
|
||||
|
||||
AGENT_CAROUSEL = "agent_carousel"
|
||||
AGENTS_FOUND = "agents_found"
|
||||
AGENT_DETAILS = "agent_details"
|
||||
SETUP_REQUIREMENTS = "setup_requirements"
|
||||
EXECUTION_STARTED = "execution_started"
|
||||
NEED_LOGIN = "need_login"
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
SUCCESS = "success"
|
||||
AGENT_OUTPUT = "agent_output"
|
||||
UNDERSTANDING_UPDATED = "understanding_updated"
|
||||
BLOCK_LIST = "block_list"
|
||||
BLOCK_OUTPUT = "block_output"
|
||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||
DOC_PAGE = "doc_page"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -51,14 +57,14 @@ class AgentInfo(BaseModel):
|
||||
graph_id: str | None = None
|
||||
|
||||
|
||||
class AgentCarouselResponse(ToolResponseBase):
|
||||
class AgentsFoundResponse(ToolResponseBase):
|
||||
"""Response for find_agent tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_CAROUSEL
|
||||
type: ResponseType = ResponseType.AGENTS_FOUND
|
||||
title: str = "Available Agents"
|
||||
agents: list[AgentInfo]
|
||||
count: int
|
||||
name: str = "agent_carousel"
|
||||
name: str = "agents_found"
|
||||
|
||||
|
||||
class NoResultsResponse(ToolResponseBase):
|
||||
@@ -173,3 +179,117 @@ class ErrorResponse(ToolResponseBase):
|
||||
type: ResponseType = ResponseType.ERROR
|
||||
error: str | None = None
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# Agent output models
|
||||
class ExecutionOutputInfo(BaseModel):
|
||||
"""Summary of a single execution's outputs."""
|
||||
|
||||
execution_id: str
|
||||
status: str
|
||||
started_at: datetime | None = None
|
||||
ended_at: datetime | None = None
|
||||
outputs: dict[str, list[Any]]
|
||||
inputs_summary: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AgentOutputResponse(ToolResponseBase):
|
||||
"""Response for agent_output tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_OUTPUT
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
library_agent_id: str | None = None
|
||||
library_agent_link: str | None = None
|
||||
execution: ExecutionOutputInfo | None = None
|
||||
available_executions: list[dict[str, Any]] | None = None
|
||||
total_executions: int = 0
|
||||
|
||||
|
||||
# Business understanding models
|
||||
class UnderstandingUpdatedResponse(ToolResponseBase):
|
||||
"""Response for add_understanding tool."""
|
||||
|
||||
type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
|
||||
updated_fields: list[str] = Field(default_factory=list)
|
||||
current_understanding: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# Documentation search models
|
||||
class DocSearchResult(BaseModel):
|
||||
"""A single documentation search result."""
|
||||
|
||||
title: str
|
||||
path: str
|
||||
section: str
|
||||
snippet: str # Short excerpt for UI display
|
||||
score: float
|
||||
doc_url: str | None = None
|
||||
|
||||
|
||||
class DocSearchResultsResponse(ToolResponseBase):
|
||||
"""Response for search_docs tool."""
|
||||
|
||||
type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
|
||||
results: list[DocSearchResult]
|
||||
count: int
|
||||
query: str
|
||||
|
||||
|
||||
class DocPageResponse(ToolResponseBase):
|
||||
"""Response for get_doc_page tool."""
|
||||
|
||||
type: ResponseType = ResponseType.DOC_PAGE
|
||||
title: str
|
||||
path: str
|
||||
content: str # Full document content
|
||||
doc_url: str | None = None
|
||||
|
||||
|
||||
# Block models
|
||||
class BlockInputFieldInfo(BaseModel):
|
||||
"""Information about a block input field."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
description: str = ""
|
||||
required: bool = False
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
class BlockInfoSummary(BaseModel):
|
||||
"""Summary of a block for search results."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
||||
default_factory=list,
|
||||
description="List of required input fields for this block",
|
||||
)
|
||||
|
||||
|
||||
class BlockListResponse(ToolResponseBase):
|
||||
"""Response for find_block tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BLOCK_LIST
|
||||
blocks: list[BlockInfoSummary]
|
||||
count: int
|
||||
query: str
|
||||
usage_hint: str = Field(
|
||||
default="To execute a block, call run_block with block_id set to the block's "
|
||||
"'id' field and input_data containing the required fields from input_schema."
|
||||
)
|
||||
|
||||
|
||||
class BlockOutputResponse(ToolResponseBase):
|
||||
"""Response for run_block tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BLOCK_OUTPUT
|
||||
block_id: str
|
||||
block_name: str
|
||||
outputs: dict[str, list[Any]]
|
||||
success: bool = True
|
||||
@@ -5,14 +5,22 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
from backend.util.timezone_utils import (
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentDetails,
|
||||
AgentDetailsResponse,
|
||||
ErrorResponse,
|
||||
@@ -23,19 +31,13 @@ from backend.server.v2.chat.tools.models import (
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from backend.server.v2.chat.tools.utils import (
|
||||
from .utils import (
|
||||
check_user_has_required_credentials,
|
||||
extract_credentials_from_schema,
|
||||
fetch_graph_from_store_slug,
|
||||
get_or_create_library_agent,
|
||||
match_user_credentials_to_graph,
|
||||
)
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
from backend.util.timezone_utils import (
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
@@ -56,6 +58,7 @@ class RunAgentInput(BaseModel):
|
||||
"""Input parameters for the run_agent tool."""
|
||||
|
||||
username_agent_slug: str = ""
|
||||
library_agent_id: str = ""
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
use_defaults: bool = False
|
||||
schedule_name: str = ""
|
||||
@@ -63,7 +66,12 @@ class RunAgentInput(BaseModel):
|
||||
timezone: str = "UTC"
|
||||
|
||||
@field_validator(
|
||||
"username_agent_slug", "schedule_name", "cron", "timezone", mode="before"
|
||||
"username_agent_slug",
|
||||
"library_agent_id",
|
||||
"schedule_name",
|
||||
"cron",
|
||||
"timezone",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> Any:
|
||||
@@ -89,7 +97,7 @@ class RunAgentTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Run or schedule an agent from the marketplace.
|
||||
return """Run or schedule an agent from the marketplace or user's library.
|
||||
|
||||
The tool automatically handles the setup flow:
|
||||
- Returns missing inputs if required fields are not provided
|
||||
@@ -97,6 +105,10 @@ class RunAgentTool(BaseTool):
|
||||
- Executes immediately if all requirements are met
|
||||
- Schedules execution if cron expression is provided
|
||||
|
||||
Identify the agent using either:
|
||||
- username_agent_slug: Marketplace format 'username/agent-name'
|
||||
- library_agent_id: ID of an agent in the user's library
|
||||
|
||||
For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
|
||||
|
||||
@property
|
||||
@@ -108,6 +120,10 @@ class RunAgentTool(BaseTool):
|
||||
"type": "string",
|
||||
"description": "Agent identifier in format 'username/agent-name'",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Library agent ID from user's library",
|
||||
},
|
||||
"inputs": {
|
||||
"type": "object",
|
||||
"description": "Input values for the agent",
|
||||
@@ -130,7 +146,7 @@ class RunAgentTool(BaseTool):
|
||||
"description": "IANA timezone for schedule (default: UTC)",
|
||||
},
|
||||
},
|
||||
"required": ["username_agent_slug"],
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -148,10 +164,16 @@ class RunAgentTool(BaseTool):
|
||||
params = RunAgentInput(**kwargs)
|
||||
session_id = session.session_id
|
||||
|
||||
# Validate agent slug format
|
||||
if not params.username_agent_slug or "/" not in params.username_agent_slug:
|
||||
# Validate at least one identifier is provided
|
||||
has_slug = params.username_agent_slug and "/" in params.username_agent_slug
|
||||
has_library_id = bool(params.library_agent_id)
|
||||
|
||||
if not has_slug and not has_library_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an agent slug in format 'username/agent-name'",
|
||||
message=(
|
||||
"Please provide either a username_agent_slug "
|
||||
"(format 'username/agent-name') or a library_agent_id"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -166,13 +188,41 @@ class RunAgentTool(BaseTool):
|
||||
is_schedule = bool(params.schedule_name or params.cron)
|
||||
|
||||
try:
|
||||
# Step 1: Fetch agent details (always happens first)
|
||||
username, agent_name = params.username_agent_slug.split("/", 1)
|
||||
graph, store_agent = await fetch_graph_from_store_slug(username, agent_name)
|
||||
# Step 1: Fetch agent details
|
||||
graph: GraphModel | None = None
|
||||
library_agent = None
|
||||
|
||||
# Priority: library_agent_id if provided
|
||||
if has_library_id:
|
||||
library_agent = await library_db.get_library_agent(
|
||||
params.library_agent_id, user_id
|
||||
)
|
||||
if not library_agent:
|
||||
return ErrorResponse(
|
||||
message=f"Library agent '{params.library_agent_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id,
|
||||
library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
else:
|
||||
# Fetch from marketplace slug
|
||||
username, agent_name = params.username_agent_slug.split("/", 1)
|
||||
graph, _ = await fetch_graph_from_store_slug(username, agent_name)
|
||||
|
||||
if not graph:
|
||||
identifier = (
|
||||
params.library_agent_id
|
||||
if has_library_id
|
||||
else params.username_agent_slug
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"Agent '{params.username_agent_slug}' not found in marketplace",
|
||||
message=f"Agent '{identifier}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
from ._test_data import (
|
||||
make_session,
|
||||
setup_firecrawl_test_data,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
)
|
||||
from backend.server.v2.chat.tools.run_agent import RunAgentTool
|
||||
from .run_agent import RunAgentTool
|
||||
|
||||
# This is so the formatter doesn't remove the fixture imports
|
||||
setup_llm_test_data = setup_llm_test_data
|
||||
@@ -17,6 +18,17 @@ setup_test_data = setup_test_data
|
||||
setup_firecrawl_test_data = setup_firecrawl_test_data
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_run_agent(setup_test_data):
|
||||
"""Test that the run_agent tool successfully executes an approved agent"""
|
||||
@@ -46,11 +58,11 @@ async def test_run_agent(setup_test_data):
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
# Parse the result JSON to verify the execution started
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "execution_id" in result_data
|
||||
assert "graph_id" in result_data
|
||||
assert result_data["graph_id"] == graph.id
|
||||
@@ -86,11 +98,11 @@ async def test_run_agent_missing_inputs(setup_test_data):
|
||||
|
||||
# Verify that we get an error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
# The tool should return an ErrorResponse when setup info indicates not ready
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "message" in result_data
|
||||
|
||||
|
||||
@@ -118,10 +130,10 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
|
||||
|
||||
# Verify that we get an error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "message" in result_data
|
||||
# Should get an error about failed setup or not found
|
||||
assert any(
|
||||
@@ -158,12 +170,12 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
|
||||
# Parse the result JSON to verify the execution started
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should successfully start execution since credentials are available
|
||||
assert "execution_id" in result_data
|
||||
@@ -195,9 +207,9 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return agent_details type showing available inputs
|
||||
assert result_data.get("type") == "agent_details"
|
||||
@@ -230,9 +242,9 @@ async def test_run_agent_with_use_defaults(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should execute successfully
|
||||
assert "execution_id" in result_data
|
||||
@@ -260,9 +272,9 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return setup_requirements type with missing credentials
|
||||
assert result_data.get("type") == "setup_requirements"
|
||||
@@ -292,9 +304,9 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error
|
||||
assert result_data.get("type") == "error"
|
||||
@@ -305,9 +317,10 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||
async def test_run_agent_unauthenticated():
|
||||
"""Test that run_agent returns need_login for unauthenticated users."""
|
||||
tool = RunAgentTool()
|
||||
session = make_session(user_id=None)
|
||||
# Session has a user_id (session owner), but we test tool execution without user_id
|
||||
session = make_session(user_id="test-session-owner")
|
||||
|
||||
# Execute without user_id
|
||||
# Execute without user_id to test unauthenticated behavior
|
||||
response = await tool.execute(
|
||||
user_id=None,
|
||||
session_id=str(uuid.uuid4()),
|
||||
@@ -318,9 +331,9 @@ async def test_run_agent_unauthenticated():
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Base tool returns need_login type for unauthenticated users
|
||||
assert result_data.get("type") == "need_login"
|
||||
@@ -350,9 +363,9 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error about missing cron
|
||||
assert result_data.get("type") == "error"
|
||||
@@ -382,9 +395,9 @@ async def test_run_agent_schedule_without_name(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error about missing schedule_name
|
||||
assert result_data.get("type") == "error"
|
||||
@@ -0,0 +1,297 @@
|
||||
"""Tool for executing blocks directly."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunBlockTool(BaseTool):
|
||||
"""Tool for executing a block and returning its outputs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "run_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a specific block with the provided input data. "
|
||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
||||
"do NOT guess or make up block IDs. "
|
||||
"Use the 'id' from find_block results and provide input_data "
|
||||
"matching the block's required_inputs."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"block_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The block's 'id' field from find_block results. "
|
||||
"NEVER guess this - always get it from find_block first."
|
||||
),
|
||||
},
|
||||
"input_data": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Input values for the block. Use the 'required_inputs' field "
|
||||
"from find_block to see what fields are needed."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["block_id", "input_data"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _check_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
missing_credentials: list[CredentialsMetaInput] = []
|
||||
|
||||
# Get credential field info from block's input schema
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
if not credentials_fields_info:
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
# Get user's available credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
# field_info.provider is a frozenset of acceptable providers
|
||||
# field_info.supported_types is a frozenset of acceptable types
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in field_info.provider
|
||||
and cred.type in field_info.supported_types
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if matching_cred:
|
||||
matched_credentials[field_name] = CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
else:
|
||||
# Create a placeholder for the missing credential
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing_credentials.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with the given input data.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
block_id: Block UUID to execute
|
||||
input_data: Input values for the block
|
||||
|
||||
Returns:
|
||||
BlockOutputResponse: Block execution outputs
|
||||
SetupRequirementsResponse: Missing credentials
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
block_id = kwargs.get("block_id", "").strip()
|
||||
input_data = kwargs.get("input_data", {})
|
||||
session_id = session.session_id
|
||||
|
||||
if not block_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a block_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not isinstance(input_data, dict):
|
||||
return ErrorResponse(
|
||||
message="input_data must be an object",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get the block
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
# Check credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||
user_id, block
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
missing_creds_dict = {c.id: c.model_dump() for c in missing_credentials}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires credentials that are not configured. "
|
||||
"Please set up the required credentials before running this block."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=block_id,
|
||||
agent_name=block.name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [c.model_dump() for c in missing_credentials],
|
||||
"inputs": self._get_inputs_list(block),
|
||||
"execution_modes": ["immediate"],
|
||||
},
|
||||
),
|
||||
graph_id=None,
|
||||
graph_version=None,
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch actual credentials and prepare kwargs for block execution
|
||||
# Create execution context with defaults (blocks may require it)
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": ExecutionContext(),
|
||||
}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
# Fetch actual credentials and pass as kwargs (for execution)
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
inputs_list = []
|
||||
schema = block.input_schema.jsonschema()
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
# Get credential field names to exclude
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in required_fields,
|
||||
}
|
||||
)
|
||||
|
||||
return inputs_list
|
||||
@@ -0,0 +1,208 @@
|
||||
"""SearchDocsTool - Search documentation using hybrid search."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
DocSearchResult,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Base URL for documentation (can be configured)
|
||||
DOCS_BASE_URL = "https://docs.agpt.co"
|
||||
|
||||
# Maximum number of results to return
|
||||
MAX_RESULTS = 5
|
||||
|
||||
# Snippet length for preview
|
||||
SNIPPET_LENGTH = 200
|
||||
|
||||
|
||||
class SearchDocsTool(BaseTool):
|
||||
"""Tool for searching AutoGPT platform documentation."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "search_docs"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search the AutoGPT platform documentation for information about "
|
||||
"how to use the platform, build agents, configure blocks, and more. "
|
||||
"Returns relevant documentation sections. Use get_doc_page to read full content."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find relevant documentation. "
|
||||
"Use natural language to describe what you're looking for."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False # Documentation is public
|
||||
|
||||
def _create_snippet(self, content: str, max_length: int = SNIPPET_LENGTH) -> str:
|
||||
"""Create a short snippet from content for preview."""
|
||||
# Remove markdown formatting for cleaner snippet
|
||||
clean_content = content.replace("#", "").replace("*", "").replace("`", "")
|
||||
# Remove extra whitespace
|
||||
clean_content = " ".join(clean_content.split())
|
||||
|
||||
if len(clean_content) <= max_length:
|
||||
return clean_content
|
||||
|
||||
# Truncate at word boundary
|
||||
truncated = clean_content[:max_length]
|
||||
last_space = truncated.rfind(" ")
|
||||
if last_space > max_length // 2:
|
||||
truncated = truncated[:last_space]
|
||||
|
||||
return truncated + "..."
|
||||
|
||||
def _make_doc_url(self, path: str) -> str:
|
||||
"""Create a URL for a documentation page."""
|
||||
# Remove file extension for URL
|
||||
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
||||
return f"{DOCS_BASE_URL}/{url_path}"
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search documentation and return relevant sections.
|
||||
|
||||
Args:
|
||||
user_id: User ID (not required for docs)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
DocSearchResultsResponse: List of matching documentation sections
|
||||
NoResultsResponse: No results found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query.",
|
||||
error="Missing query parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Search using hybrid search for DOCUMENTATION content type only
|
||||
results, total = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.DOCUMENTATION],
|
||||
page=1,
|
||||
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
|
||||
min_score=0.1, # Lower threshold for docs
|
||||
)
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{query}'.",
|
||||
suggestions=[
|
||||
"Try different keywords",
|
||||
"Use more general terms",
|
||||
"Check for typos in your query",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Deduplicate by document path (keep highest scoring section per doc)
|
||||
seen_docs: dict[str, dict[str, Any]] = {}
|
||||
for result in results:
|
||||
metadata = result.get("metadata", {})
|
||||
doc_path = metadata.get("path", "")
|
||||
|
||||
if not doc_path:
|
||||
continue
|
||||
|
||||
# Keep the highest scoring result for each document
|
||||
if doc_path not in seen_docs:
|
||||
seen_docs[doc_path] = result
|
||||
elif result.get("combined_score", 0) > seen_docs[doc_path].get(
|
||||
"combined_score", 0
|
||||
):
|
||||
seen_docs[doc_path] = result
|
||||
|
||||
# Sort by score and take top MAX_RESULTS
|
||||
deduplicated = sorted(
|
||||
seen_docs.values(),
|
||||
key=lambda x: x.get("combined_score", 0),
|
||||
reverse=True,
|
||||
)[:MAX_RESULTS]
|
||||
|
||||
if not deduplicated:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{query}'.",
|
||||
suggestions=[
|
||||
"Try different keywords",
|
||||
"Use more general terms",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build response
|
||||
doc_results: list[DocSearchResult] = []
|
||||
for result in deduplicated:
|
||||
metadata = result.get("metadata", {})
|
||||
doc_path = metadata.get("path", "")
|
||||
doc_title = metadata.get("doc_title", "")
|
||||
section_title = metadata.get("section_title", "")
|
||||
searchable_text = result.get("searchable_text", "")
|
||||
score = result.get("combined_score", 0)
|
||||
|
||||
doc_results.append(
|
||||
DocSearchResult(
|
||||
title=doc_title or section_title or doc_path,
|
||||
path=doc_path,
|
||||
section=section_title,
|
||||
snippet=self._create_snippet(searchable_text),
|
||||
score=round(score, 3),
|
||||
doc_url=self._make_doc_url(doc_path),
|
||||
)
|
||||
)
|
||||
|
||||
return DocSearchResultsResponse(
|
||||
message=f"Found {len(doc_results)} relevant documentation sections.",
|
||||
results=doc_results,
|
||||
count=len(doc_results),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Documentation search failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Failed to search documentation: {str(e)}",
|
||||
error="search_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -3,13 +3,13 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.server.v2.library import db as library_db
|
||||
from backend.server.v2.library import model as library_model
|
||||
from backend.server.v2.store import db as store_db
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -7,9 +7,10 @@ import pytest_mock
|
||||
from prisma.enums import ReviewStatus
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.server.rest_api import handle_internal_http_error
|
||||
from backend.server.v2.executions.review.model import PendingHumanReviewModel
|
||||
from backend.server.v2.executions.review.routes import router
|
||||
from backend.api.rest_api import handle_internal_http_error
|
||||
|
||||
from .model import PendingHumanReviewModel
|
||||
from .routes import router
|
||||
|
||||
# Using a fixed timestamp for reproducible tests
|
||||
FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
|
||||
@@ -54,13 +55,13 @@ def sample_pending_review(test_user_id: str) -> PendingHumanReviewModel:
|
||||
|
||||
|
||||
def test_get_pending_reviews_empty(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting pending reviews when none exist"""
|
||||
mock_get_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_user"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_user"
|
||||
)
|
||||
mock_get_reviews.return_value = []
|
||||
|
||||
@@ -72,14 +73,14 @@ def test_get_pending_reviews_empty(
|
||||
|
||||
|
||||
def test_get_pending_reviews_with_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting pending reviews with data"""
|
||||
mock_get_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_user"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_user"
|
||||
)
|
||||
mock_get_reviews.return_value = [sample_pending_review]
|
||||
|
||||
@@ -94,14 +95,14 @@ def test_get_pending_reviews_with_data(
|
||||
|
||||
|
||||
def test_get_pending_reviews_for_execution_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting pending reviews for specific execution"""
|
||||
mock_get_graph_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_graph_execution_meta"
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_get_graph_execution.return_value = {
|
||||
"id": "test_graph_exec_456",
|
||||
@@ -109,7 +110,7 @@ def test_get_pending_reviews_for_execution_success(
|
||||
}
|
||||
|
||||
mock_get_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews.return_value = [sample_pending_review]
|
||||
|
||||
@@ -121,24 +122,23 @@ def test_get_pending_reviews_for_execution_success(
|
||||
assert data[0]["graph_exec_id"] == "test_graph_exec_456"
|
||||
|
||||
|
||||
def test_get_pending_reviews_for_execution_access_denied(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
def test_get_pending_reviews_for_execution_not_available(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Test access denied when user doesn't own the execution"""
|
||||
mock_get_graph_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_graph_execution_meta"
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_get_graph_execution.return_value = None
|
||||
|
||||
response = client.get("/api/review/execution/test_graph_exec_456")
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "Access denied" in response.json()["detail"]
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_process_review_action_approve_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -146,12 +146,12 @@ def test_process_review_action_approve_success(
|
||||
# Mock the route functions
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
# Create approved review for return
|
||||
approved_review = PendingHumanReviewModel(
|
||||
@@ -174,11 +174,11 @@ def test_process_review_action_approve_success(
|
||||
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
mocker.patch("backend.server.v2.executions.review.routes.add_graph_execution")
|
||||
mocker.patch("backend.api.features.executions.review.routes.add_graph_execution")
|
||||
|
||||
request_data = {
|
||||
"reviews": [
|
||||
@@ -202,7 +202,7 @@ def test_process_review_action_approve_success(
|
||||
|
||||
|
||||
def test_process_review_action_reject_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -210,12 +210,12 @@ def test_process_review_action_reject_success(
|
||||
# Mock the route functions
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
rejected_review = PendingHumanReviewModel(
|
||||
node_exec_id="test_node_123",
|
||||
@@ -237,7 +237,7 @@ def test_process_review_action_reject_success(
|
||||
mock_process_all_reviews.return_value = {"test_node_123": rejected_review}
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
@@ -262,7 +262,7 @@ def test_process_review_action_reject_success(
|
||||
|
||||
|
||||
def test_process_review_action_mixed_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -289,12 +289,12 @@ def test_process_review_action_mixed_success(
|
||||
# Mock the route functions
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review, second_review]
|
||||
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
# Create approved version of first review
|
||||
approved_review = PendingHumanReviewModel(
|
||||
@@ -338,7 +338,7 @@ def test_process_review_action_mixed_success(
|
||||
}
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
@@ -369,7 +369,7 @@ def test_process_review_action_mixed_success(
|
||||
|
||||
|
||||
def test_process_review_action_empty_request(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error when no reviews provided"""
|
||||
@@ -386,19 +386,19 @@ def test_process_review_action_empty_request(
|
||||
|
||||
|
||||
def test_process_review_action_review_not_found(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error when review is not found"""
|
||||
# Mock the functions that extract graph execution ID from the request
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [] # No reviews found
|
||||
|
||||
# Mock process_all_reviews to simulate not finding reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
# This should raise a ValueError with "Reviews not found" message based on the data/human_review.py logic
|
||||
mock_process_all_reviews.side_effect = ValueError(
|
||||
@@ -422,20 +422,20 @@ def test_process_review_action_review_not_found(
|
||||
|
||||
|
||||
def test_process_review_action_partial_failure(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test handling of partial failures in review processing"""
|
||||
# Mock the route functions
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
# Mock partial failure in processing
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
mock_process_all_reviews.side_effect = ValueError("Some reviews failed validation")
|
||||
|
||||
@@ -456,20 +456,20 @@ def test_process_review_action_partial_failure(
|
||||
|
||||
|
||||
def test_process_review_action_invalid_node_exec_id(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test failure when trying to process review with invalid node execution ID"""
|
||||
# Mock the route functions
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.get_pending_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
|
||||
# Mock validation failure - this should return 400, not 500
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.server.v2.executions.review.routes.process_all_reviews_for_execution"
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
mock_process_all_reviews.side_effect = ValueError(
|
||||
"Invalid node execution ID format"
|
||||
@@ -13,11 +13,8 @@ from backend.data.human_review import (
|
||||
process_all_reviews_for_execution,
|
||||
)
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.v2.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
ReviewRequest,
|
||||
ReviewResponse,
|
||||
)
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,8 +67,7 @@ async def list_pending_reviews(
|
||||
response_model=List[PendingHumanReviewModel],
|
||||
responses={
|
||||
200: {"description": "List of pending reviews for the execution"},
|
||||
400: {"description": "Invalid graph execution ID"},
|
||||
403: {"description": "Access denied to graph execution"},
|
||||
404: {"description": "Graph execution not found"},
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
)
|
||||
@@ -94,7 +90,7 @@ async def list_pending_reviews_for_execution(
|
||||
|
||||
Raises:
|
||||
HTTPException:
|
||||
- 403: If user doesn't own the graph execution
|
||||
- 404: If the graph execution doesn't exist or isn't owned by this user
|
||||
- 500: If authentication fails or database error occurs
|
||||
|
||||
Note:
|
||||
@@ -108,8 +104,8 @@ async def list_pending_reviews_for_execution(
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to graph execution",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||
@@ -17,6 +17,8 @@ from fastapi import (
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||
|
||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||
from backend.api.features.library.model import LibraryAgentPreset
|
||||
from backend.data.graph import NodeModel, get_graph, set_node_webhook
|
||||
from backend.data.integrations import (
|
||||
WebhookEvent,
|
||||
@@ -33,11 +35,7 @@ from backend.data.model import (
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
)
|
||||
from backend.data.onboarding import (
|
||||
OnboardingStep,
|
||||
complete_onboarding_step,
|
||||
increment_runs,
|
||||
)
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
@@ -45,13 +43,6 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.integrations.models import (
|
||||
ProviderConstants,
|
||||
ProviderNamesResponse,
|
||||
get_all_provider_names,
|
||||
)
|
||||
from backend.server.v2.library.db import set_preset_webhook, update_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
from backend.util.exceptions import (
|
||||
GraphNotInLibraryError,
|
||||
MissingConfigError,
|
||||
@@ -60,6 +51,8 @@ from backend.util.exceptions import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import ProviderConstants, ProviderNamesResponse, get_all_provider_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.integrations.oauth import BaseOAuthHandler
|
||||
|
||||
@@ -178,6 +171,7 @@ async def callback(
|
||||
f"Successfully processed OAuth callback for user {user_id} "
|
||||
f"and provider {provider.value}"
|
||||
)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
@@ -196,6 +190,7 @@ async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
@@ -218,6 +213,7 @@ async def list_credentials_by_provider(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
@@ -381,7 +377,6 @@ async def webhook_ingress_generic(
|
||||
return
|
||||
|
||||
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
|
||||
await increment_runs(user_id)
|
||||
|
||||
# Execute all triggers concurrently for better performance
|
||||
tasks = []
|
||||
@@ -834,6 +829,18 @@ async def list_providers() -> List[str]:
|
||||
return all_providers
|
||||
|
||||
|
||||
@router.get("/providers/system", response_model=List[str])
|
||||
async def list_system_providers() -> List[str]:
|
||||
"""
|
||||
Get a list of providers that have platform credits (system credentials) available.
|
||||
|
||||
These providers can be used without the user providing their own API keys.
|
||||
"""
|
||||
from backend.integrations.credentials_store import SYSTEM_PROVIDERS
|
||||
|
||||
return list(SYSTEM_PROVIDERS)
|
||||
|
||||
|
||||
@router.get("/providers/names", response_model=ProviderNamesResponse)
|
||||
async def get_provider_names() -> ProviderNamesResponse:
|
||||
"""
|
||||
@@ -4,16 +4,14 @@ from typing import Literal, Optional
|
||||
|
||||
import fastapi
|
||||
import prisma.errors
|
||||
import prisma.fields
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
import backend.api.features.store.image_gen as store_image_gen
|
||||
import backend.api.features.store.media as store_media
|
||||
import backend.data.graph as graph_db
|
||||
import backend.data.integrations as integrations_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.server.v2.store.image_gen as store_image_gen
|
||||
import backend.server.v2.store.media as store_media
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.db import transaction
|
||||
from backend.data.execution import get_graph_execution
|
||||
@@ -28,6 +26,8 @@ from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.settings import Config
|
||||
|
||||
from . import model as library_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
@@ -489,7 +489,7 @@ async def update_agent_version_in_library(
|
||||
agent_graph_version: int,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Updates the agent version in the library if useGraphIsActiveVersion is True.
|
||||
Updates the agent version in the library for any agent owned by the user.
|
||||
|
||||
Args:
|
||||
user_id: Owner of the LibraryAgent.
|
||||
@@ -498,20 +498,31 @@ async def update_agent_version_in_library(
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's an error with the update.
|
||||
NotFoundError: If no library agent is found for this user and agent.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Updating agent version in library for user #{user_id}, "
|
||||
f"agent #{agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
try:
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
|
||||
async with transaction() as tx:
|
||||
library_agent = await prisma.models.LibraryAgent.prisma(tx).find_first_or_raise(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": agent_graph_id,
|
||||
"useGraphIsActiveVersion": True,
|
||||
},
|
||||
)
|
||||
lib = await prisma.models.LibraryAgent.prisma().update(
|
||||
|
||||
# Delete any conflicting LibraryAgent for the target version
|
||||
await prisma.models.LibraryAgent.prisma(tx).delete_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": agent_graph_id,
|
||||
"agentGraphVersion": agent_graph_version,
|
||||
"id": {"not": library_agent.id},
|
||||
}
|
||||
)
|
||||
|
||||
lib = await prisma.models.LibraryAgent.prisma(tx).update(
|
||||
where={"id": library_agent.id},
|
||||
data={
|
||||
"AgentGraph": {
|
||||
@@ -525,19 +536,20 @@ async def update_agent_version_in_library(
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
if lib is None:
|
||||
raise NotFoundError(f"Library agent {library_agent.id} not found")
|
||||
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating agent version in library: {e}")
|
||||
raise DatabaseError("Failed to update agent version in library") from e
|
||||
if lib is None:
|
||||
raise NotFoundError(
|
||||
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
|
||||
|
||||
async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str,
|
||||
auto_update_version: Optional[bool] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
is_favorite: Optional[bool] = None,
|
||||
is_archived: Optional[bool] = None,
|
||||
is_deleted: Optional[Literal[False]] = None,
|
||||
@@ -550,6 +562,7 @@ async def update_library_agent(
|
||||
library_agent_id: The ID of the LibraryAgent to update.
|
||||
user_id: The owner of this LibraryAgent.
|
||||
auto_update_version: Whether the agent should auto-update to active version.
|
||||
graph_version: Specific graph version to update to.
|
||||
is_favorite: Whether this agent is marked as a favorite.
|
||||
is_archived: Whether this agent is archived.
|
||||
settings: User-specific settings for this library agent.
|
||||
@@ -563,8 +576,8 @@ async def update_library_agent(
|
||||
"""
|
||||
logger.debug(
|
||||
f"Updating library agent {library_agent_id} for user {user_id} with "
|
||||
f"auto_update_version={auto_update_version}, is_favorite={is_favorite}, "
|
||||
f"is_archived={is_archived}, settings={settings}"
|
||||
f"auto_update_version={auto_update_version}, graph_version={graph_version}, "
|
||||
f"is_favorite={is_favorite}, is_archived={is_archived}, settings={settings}"
|
||||
)
|
||||
update_fields: prisma.types.LibraryAgentUpdateManyMutationInput = {}
|
||||
if auto_update_version is not None:
|
||||
@@ -581,10 +594,23 @@ async def update_library_agent(
|
||||
update_fields["isDeleted"] = is_deleted
|
||||
if settings is not None:
|
||||
update_fields["settings"] = SafeJson(settings.model_dump())
|
||||
if not update_fields:
|
||||
raise ValueError("No values were passed to update")
|
||||
|
||||
try:
|
||||
# If graph_version is provided, update to that specific version
|
||||
if graph_version is not None:
|
||||
# Get the current agent to find its graph_id
|
||||
agent = await get_library_agent(id=library_agent_id, user_id=user_id)
|
||||
# Update to the specified version using existing function
|
||||
return await update_agent_version_in_library(
|
||||
user_id=user_id,
|
||||
agent_graph_id=agent.graph_id,
|
||||
agent_graph_version=graph_version,
|
||||
)
|
||||
|
||||
# Otherwise, just update the simple fields
|
||||
if not update_fields:
|
||||
raise ValueError("No values were passed to update")
|
||||
|
||||
n_updated = await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={"id": library_agent_id, "userId": user_id},
|
||||
data=update_fields,
|
||||
@@ -810,6 +836,7 @@ async def add_store_agent_to_library(
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(
|
||||
_initialize_graph_settings(graph_model).model_dump()
|
||||
),
|
||||
@@ -1,16 +1,15 @@
|
||||
from datetime import datetime
|
||||
|
||||
import prisma.enums
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.api.features.store.exceptions
|
||||
from backend.data.db import connect
|
||||
from backend.data.includes import library_agent_include
|
||||
|
||||
from . import db
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agents(mocker):
|
||||
@@ -88,7 +87,7 @@ async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
|
||||
# Mock the transaction context
|
||||
mock_transaction = mocker.patch("backend.server.v2.library.db.transaction")
|
||||
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
|
||||
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
|
||||
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
|
||||
# Mock data
|
||||
@@ -151,7 +150,7 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
||||
mock_graph_db = mocker.patch("backend.server.v2.library.db.graph_db")
|
||||
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
|
||||
mock_graph_model = mocker.Mock()
|
||||
mock_graph_model.nodes = (
|
||||
[]
|
||||
@@ -159,7 +158,9 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_graph_db.get_graph = mocker.AsyncMock(return_value=mock_graph_model)
|
||||
|
||||
# Mock the model conversion
|
||||
mock_from_db = mocker.patch("backend.server.v2.library.model.LibraryAgent.from_db")
|
||||
mock_from_db = mocker.patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db"
|
||||
)
|
||||
mock_from_db.return_value = mocker.Mock()
|
||||
|
||||
# Call function
|
||||
@@ -217,7 +218,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
)
|
||||
|
||||
# Call function and verify exception
|
||||
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
|
||||
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mock called correctly
|
||||
@@ -48,6 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str # ID of user who owns/created this agent graph
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -163,6 +164,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
@@ -385,6 +387,9 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
|
||||
auto_update_version: Optional[bool] = pydantic.Field(
|
||||
default=None, description="Auto-update the agent version"
|
||||
)
|
||||
graph_version: Optional[int] = pydantic.Field(
|
||||
default=None, description="Specific graph version to update to"
|
||||
)
|
||||
is_favorite: Optional[bool] = pydantic.Field(
|
||||
default=None, description="Mark the agent as a favorite"
|
||||
)
|
||||
@@ -3,7 +3,7 @@ import datetime
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
from . import model as library_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -6,12 +6,13 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
from fastapi.responses import Response
|
||||
from prisma.enums import OnboardingStep
|
||||
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
from backend.data.onboarding import complete_onboarding_step
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .. import db as library_db
|
||||
from .. import model as library_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
@@ -284,6 +285,7 @@ async def update_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
auto_update_version=payload.auto_update_version,
|
||||
graph_version=payload.graph_version,
|
||||
is_favorite=payload.is_favorite,
|
||||
is_archived=payload.is_archived,
|
||||
settings=payload.settings,
|
||||
@@ -4,19 +4,19 @@ from typing import Any, Optional
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.onboarding import increment_runs
|
||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .. import db
|
||||
from .. import model as models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
@@ -402,8 +402,6 @@ async def execute_preset(
|
||||
merged_node_input = preset.inputs | inputs
|
||||
merged_credential_inputs = preset.credentials | credential_inputs
|
||||
|
||||
await increment_runs(user_id)
|
||||
|
||||
return await add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=preset.graph_id,
|
||||
@@ -7,10 +7,11 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
from backend.server.v2.library.routes import router as library_router
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import model as library_model
|
||||
from .routes import router as library_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(library_router)
|
||||
|
||||
@@ -41,6 +42,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -63,6 +65,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -86,7 +89,7 @@ async def test_get_library_agents_success(
|
||||
total_items=2, total_pages=1, current_page=1, page_size=50
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/agents?search_term=test")
|
||||
@@ -112,7 +115,7 @@ async def test_get_library_agents_success(
|
||||
|
||||
|
||||
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
|
||||
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.get("/agents?search_term=test")
|
||||
@@ -137,6 +140,7 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -161,7 +165,7 @@ async def test_get_favorite_library_agents_success(
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
"backend.api.features.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
@@ -184,7 +188,7 @@ def test_get_favorite_library_agents_error(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
"backend.api.features.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
@@ -204,6 +208,7 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -223,11 +228,11 @@ def test_add_agent_to_library_success(
|
||||
)
|
||||
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
"backend.api.features.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.return_value = mock_library_agent
|
||||
mock_complete_onboarding = mocker.patch(
|
||||
"backend.server.v2.library.routes.agents.complete_onboarding_step",
|
||||
"backend.api.features.library.routes.agents.complete_onboarding_step",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
@@ -249,7 +254,7 @@ def test_add_agent_to_library_success(
|
||||
|
||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
"backend.api.features.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
833
autogpt_platform/backend/backend/api/features/oauth.py
Normal file
833
autogpt_platform/backend/backend/api/features/oauth.py
Normal file
@@ -0,0 +1,833 @@
|
||||
"""
|
||||
OAuth 2.0 Provider Endpoints
|
||||
|
||||
Implements OAuth 2.0 Authorization Code flow with PKCE support.
|
||||
|
||||
Flow:
|
||||
1. User clicks "Login with AutoGPT" in 3rd party app
|
||||
2. App redirects user to /auth/authorize with client_id, redirect_uri, scope, state
|
||||
3. User sees consent screen (if not already logged in, redirects to login first)
|
||||
4. User approves → backend creates authorization code
|
||||
5. User redirected back to app with code
|
||||
6. App exchanges code for access/refresh tokens at /api/oauth/token
|
||||
7. App uses access token to call external API endpoints
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
|
||||
from gcloud.aio import storage as async_storage
|
||||
from PIL import Image
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidGrantError,
|
||||
OAuthApplicationInfo,
|
||||
TokenIntrospectionResult,
|
||||
consume_authorization_code,
|
||||
create_access_token,
|
||||
create_authorization_code,
|
||||
create_refresh_token,
|
||||
get_oauth_application,
|
||||
get_oauth_application_by_id,
|
||||
introspect_token,
|
||||
list_user_oauth_applications,
|
||||
refresh_tokens,
|
||||
revoke_access_token,
|
||||
revoke_refresh_token,
|
||||
update_oauth_application,
|
||||
validate_client_credentials,
|
||||
validate_redirect_uri,
|
||||
validate_scopes,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""OAuth 2.0 token response"""
|
||||
|
||||
token_type: Literal["Bearer"] = "Bearer"
|
||||
access_token: str
|
||||
access_token_expires_at: datetime
|
||||
refresh_token: str
|
||||
refresh_token_expires_at: datetime
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""OAuth 2.0 error response"""
|
||||
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthApplicationPublicInfo(BaseModel):
|
||||
"""Public information about an OAuth application (for consent screen)"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Info Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/app/{client_id}",
|
||||
responses={
|
||||
404: {"description": "Application not found or disabled"},
|
||||
},
|
||||
)
|
||||
async def get_oauth_app_info(
|
||||
client_id: str, user_id: str = Security(get_user_id)
|
||||
) -> OAuthApplicationPublicInfo:
|
||||
"""
|
||||
Get public information about an OAuth application.
|
||||
|
||||
This endpoint is used by the consent screen to display application details
|
||||
to the user before they authorize access.
|
||||
|
||||
Returns:
|
||||
- name: Application name
|
||||
- description: Application description (if provided)
|
||||
- scopes: List of scopes the application is allowed to request
|
||||
"""
|
||||
app = await get_oauth_application(client_id)
|
||||
if not app or not app.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found",
|
||||
)
|
||||
|
||||
return OAuthApplicationPublicInfo(
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
logo_url=app.logo_url,
|
||||
scopes=[s.value for s in app.scopes],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AuthorizeRequest(BaseModel):
|
||||
"""OAuth 2.0 authorization request"""
|
||||
|
||||
client_id: str = Field(description="Client identifier")
|
||||
redirect_uri: str = Field(description="Redirect URI")
|
||||
scopes: list[str] = Field(description="List of scopes")
|
||||
state: str = Field(description="Anti-CSRF token from client")
|
||||
response_type: str = Field(
|
||||
default="code", description="Must be 'code' for authorization code flow"
|
||||
)
|
||||
code_challenge: str = Field(description="PKCE code challenge (required)")
|
||||
code_challenge_method: Literal["S256", "plain"] = Field(
|
||||
default="S256", description="PKCE code challenge method (S256 recommended)"
|
||||
)
|
||||
|
||||
|
||||
class AuthorizeResponse(BaseModel):
|
||||
"""OAuth 2.0 authorization response with redirect URL"""
|
||||
|
||||
redirect_url: str = Field(description="URL to redirect the user to")
|
||||
|
||||
|
||||
@router.post("/authorize")
|
||||
async def authorize(
|
||||
request: AuthorizeRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> AuthorizeResponse:
|
||||
"""
|
||||
OAuth 2.0 Authorization Endpoint
|
||||
|
||||
User must be logged in (authenticated with Supabase JWT).
|
||||
This endpoint creates an authorization code and returns a redirect URL.
|
||||
|
||||
PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.
|
||||
|
||||
The frontend consent screen should call this endpoint after the user approves,
|
||||
then redirect the user to the returned `redirect_url`.
|
||||
|
||||
Request Body:
|
||||
- client_id: The OAuth application's client ID
|
||||
- redirect_uri: Where to redirect after authorization (must match registered URI)
|
||||
- scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
|
||||
- state: Anti-CSRF token provided by client (will be returned in redirect)
|
||||
- response_type: Must be "code" (for authorization code flow)
|
||||
- code_challenge: PKCE code challenge (required)
|
||||
- code_challenge_method: "S256" (recommended) or "plain"
|
||||
|
||||
Returns:
|
||||
- redirect_url: The URL to redirect the user to (includes authorization code)
|
||||
|
||||
Error cases return a redirect_url with error parameters, or raise HTTPException
|
||||
for critical errors (like invalid redirect_uri).
|
||||
"""
|
||||
try:
|
||||
# Validate response_type
|
||||
if request.response_type != "code":
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"unsupported_response_type",
|
||||
"Only 'code' response type is supported",
|
||||
)
|
||||
|
||||
# Get application
|
||||
app = await get_oauth_application(request.client_id)
|
||||
if not app:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Unknown client_id",
|
||||
)
|
||||
|
||||
if not app.is_active:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Application is not active",
|
||||
)
|
||||
|
||||
# Validate redirect URI
|
||||
if not validate_redirect_uri(app, request.redirect_uri):
|
||||
# For invalid redirect_uri, we can't redirect safely
|
||||
# Must return error instead
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"Invalid redirect_uri. "
|
||||
f"Must be one of: {', '.join(app.redirect_uris)}"
|
||||
),
|
||||
)
|
||||
|
||||
# Parse and validate scopes
|
||||
try:
|
||||
requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
|
||||
except ValueError as e:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
f"Invalid scope: {e}",
|
||||
)
|
||||
|
||||
if not requested_scopes:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"At least one scope is required",
|
||||
)
|
||||
|
||||
if not validate_scopes(app, requested_scopes):
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"Application is not authorized for all requested scopes. "
|
||||
f"Allowed: {', '.join(s.value for s in app.scopes)}",
|
||||
)
|
||||
|
||||
# Create authorization code
|
||||
auth_code = await create_authorization_code(
|
||||
application_id=app.id,
|
||||
user_id=user_id,
|
||||
scopes=requested_scopes,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_challenge=request.code_challenge,
|
||||
code_challenge_method=request.code_challenge_method,
|
||||
)
|
||||
|
||||
# Build redirect URL with authorization code
|
||||
params = {
|
||||
"code": auth_code.code,
|
||||
"state": request.state,
|
||||
}
|
||||
redirect_url = f"{request.redirect_uri}?{urlencode(params)}"
|
||||
|
||||
logger.info(
|
||||
f"Authorization code issued for user #{user_id} "
|
||||
f"and app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"server_error",
|
||||
"An unexpected error occurred",
|
||||
)
|
||||
|
||||
|
||||
def _error_redirect_url(
|
||||
redirect_uri: str,
|
||||
state: str,
|
||||
error: str,
|
||||
error_description: Optional[str] = None,
|
||||
) -> AuthorizeResponse:
|
||||
"""Helper to build redirect URL with OAuth error parameters"""
|
||||
params = {
|
||||
"error": error,
|
||||
"state": state,
|
||||
}
|
||||
if error_description:
|
||||
params["error_description"] = error_description
|
||||
|
||||
redirect_url = f"{redirect_uri}?{urlencode(params)}"
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenRequestByCode(BaseModel):
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(description="Authorization code")
|
||||
redirect_uri: str = Field(
|
||||
description="Redirect URI (must match authorization request)"
|
||||
)
|
||||
client_id: str
|
||||
client_secret: str
|
||||
code_verifier: str = Field(description="PKCE code verifier")
|
||||
|
||||
|
||||
class TokenRequestByRefreshToken(BaseModel):
|
||||
grant_type: Literal["refresh_token"]
|
||||
refresh_token: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
|
||||
@router.post("/token")
|
||||
async def token(
|
||||
request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
OAuth 2.0 Token Endpoint
|
||||
|
||||
Exchanges authorization code or refresh token for access token.
|
||||
|
||||
Grant Types:
|
||||
1. authorization_code: Exchange authorization code for tokens
|
||||
- Required: grant_type, code, redirect_uri, client_id, client_secret
|
||||
- Optional: code_verifier (required if PKCE was used)
|
||||
|
||||
2. refresh_token: Exchange refresh token for new access token
|
||||
- Required: grant_type, refresh_token, client_id, client_secret
|
||||
|
||||
Returns:
|
||||
- access_token: Bearer token for API access (1 hour TTL)
|
||||
- token_type: "Bearer"
|
||||
- expires_in: Seconds until access token expires
|
||||
- refresh_token: Token for refreshing access (30 days TTL)
|
||||
- scopes: List of scopes
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(
|
||||
request.client_id, request.client_secret
|
||||
)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Handle authorization_code grant
|
||||
if request.grant_type == "authorization_code":
|
||||
# Consume authorization code
|
||||
try:
|
||||
user_id, scopes = await consume_authorization_code(
|
||||
code=request.code,
|
||||
application_id=app.id,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_verifier=request.code_verifier,
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Create access and refresh tokens
|
||||
access_token = await create_access_token(app.id, user_id, scopes)
|
||||
refresh_token = await create_refresh_token(app.id, user_id, scopes)
|
||||
|
||||
logger.info(
|
||||
f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
|
||||
"via authorization code"
|
||||
)
|
||||
|
||||
if not access_token.token or not refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=access_token.token.get_secret_value(),
|
||||
access_token_expires_at=access_token.expires_at,
|
||||
refresh_token=refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=refresh_token.expires_at,
|
||||
scopes=list(s.value for s in scopes),
|
||||
)
|
||||
|
||||
# Handle refresh_token grant
|
||||
elif request.grant_type == "refresh_token":
|
||||
# Refresh access token
|
||||
try:
|
||||
new_access_token, new_refresh_token = await refresh_tokens(
|
||||
request.refresh_token, app.id
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Tokens refreshed for user #{new_access_token.user_id} "
|
||||
f"by app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
if not new_access_token.token or not new_refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=new_access_token.token.get_secret_value(),
|
||||
access_token_expires_at=new_access_token.expires_at,
|
||||
refresh_token=new_refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=new_refresh_token.expires_at,
|
||||
scopes=list(s.value for s in new_access_token.scopes),
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported grant_type: {request.grant_type}. "
|
||||
"Must be 'authorization_code' or 'refresh_token'",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/introspect")
|
||||
async def introspect(
|
||||
token: str = Body(description="Token to introspect"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
) -> TokenIntrospectionResult:
|
||||
"""
|
||||
OAuth 2.0 Token Introspection Endpoint (RFC 7662)
|
||||
|
||||
Allows clients to check if a token is valid and get its metadata.
|
||||
|
||||
Returns:
|
||||
- active: Whether the token is currently active
|
||||
- scopes: List of authorized scopes (if active)
|
||||
- client_id: The client the token was issued to (if active)
|
||||
- user_id: The user the token represents (if active)
|
||||
- exp: Expiration timestamp (if active)
|
||||
- token_type: "access_token" or "refresh_token" (if active)
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Introspect the token
|
||||
return await introspect_token(token, token_type_hint)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/revoke")
|
||||
async def revoke(
|
||||
token: str = Body(description="Token to revoke"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
):
|
||||
"""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009)
|
||||
|
||||
Allows clients to revoke an access or refresh token.
|
||||
|
||||
Note: Revoking a refresh token does NOT revoke associated access tokens.
|
||||
Revoking an access token does NOT revoke the associated refresh token.
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Try to revoke as access token first
|
||||
# Note: We pass app.id to ensure the token belongs to the authenticated app
|
||||
if token_type_hint != "refresh_token":
|
||||
revoked = await revoke_access_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Access token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Try to revoke as refresh token
|
||||
revoked = await revoke_refresh_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Refresh token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Per RFC 7009, revocation endpoint returns 200 even if token not found
|
||||
# or if token belongs to a different application.
|
||||
# This prevents token scanning attacks.
|
||||
logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Management Endpoints (for app owners)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get("/apps/mine")
|
||||
async def list_my_oauth_apps(
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> list[OAuthApplicationInfo]:
|
||||
"""
|
||||
List all OAuth applications owned by the current user.
|
||||
|
||||
Returns a list of OAuth applications with their details including:
|
||||
- id, name, description, logo_url
|
||||
- client_id (public identifier)
|
||||
- redirect_uris, grant_types, scopes
|
||||
- is_active status
|
||||
- created_at, updated_at timestamps
|
||||
|
||||
Note: client_secret is never returned for security reasons.
|
||||
"""
|
||||
return await list_user_oauth_applications(user_id)
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/status")
|
||||
async def update_app_status(
|
||||
app_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
is_active: bool = Body(description="Whether the app should be active", embed=True),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Enable or disable an OAuth application.
|
||||
|
||||
Only the application owner can update the status.
|
||||
When disabled, the application cannot be used for new authorizations
|
||||
and existing access tokens will fail validation.
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
action = "enabled" if is_active else "disabled"
|
||||
logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
class UpdateAppLogoRequest(BaseModel):
|
||||
logo_url: str = Field(description="URL of the uploaded logo image")
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/logo")
|
||||
async def update_app_logo(
|
||||
app_id: str,
|
||||
request: UpdateAppLogoRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Update the logo URL for an OAuth application.
|
||||
|
||||
Only the application owner can update the logo.
|
||||
The logo should be uploaded first using the media upload endpoint,
|
||||
then this endpoint is called with the resulting URL.
|
||||
|
||||
Logo requirements:
|
||||
- Must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=request.logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
# Logo upload constraints
|
||||
LOGO_MIN_SIZE = 512
|
||||
LOGO_MAX_SIZE = 2048
|
||||
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
|
||||
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024 # 3MB
|
||||
|
||||
|
||||
@router.post("/apps/{app_id}/logo/upload")
|
||||
async def upload_app_logo(
|
||||
app_id: str,
|
||||
file: UploadFile,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Upload a logo image for an OAuth application.
|
||||
|
||||
Requirements:
|
||||
- Image must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
- Allowed formats: JPEG, PNG, WebP
|
||||
- Maximum file size: 3MB
|
||||
|
||||
The image is uploaded to cloud storage and the app's logoUrl is updated.
|
||||
Returns the updated application info.
|
||||
"""
|
||||
# Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Check GCS configuration
|
||||
if not settings.config.media_gcs_bucket_name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Media storage is not configured",
|
||||
)
|
||||
|
||||
# Validate content type
|
||||
content_type = file.content_type
|
||||
if content_type not in LOGO_ALLOWED_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
file_bytes = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading logo file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to read uploaded file",
|
||||
)
|
||||
|
||||
# Check file size
|
||||
if len(file_bytes) > LOGO_MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"File too large. "
|
||||
f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
|
||||
),
|
||||
)
|
||||
|
||||
# Validate image dimensions
|
||||
try:
|
||||
image = Image.open(io.BytesIO(file_bytes))
|
||||
width, height = image.size
|
||||
|
||||
if width != height:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo must be square. Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width < LOGO_MIN_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width > LOGO_MAX_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating logo image: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid image file",
|
||||
)
|
||||
|
||||
# Scan for viruses
|
||||
filename = file.filename or "logo"
|
||||
await scan_content_safe(file_bytes, filename=filename)
|
||||
|
||||
# Generate unique filename
|
||||
file_ext = os.path.splitext(filename)[1].lower() or ".png"
|
||||
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||
storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"
|
||||
|
||||
# Upload to GCS
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
|
||||
await async_client.upload(
|
||||
bucket_name, storage_path, file_bytes, content_type=content_type
|
||||
)
|
||||
|
||||
logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading logo to GCS: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to upload logo",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
# Update the app with the new logo URL
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
|
||||
"""
|
||||
Delete the current logo file for the given app, if there is one in our cloud storage
|
||||
"""
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"
|
||||
|
||||
if app.logo_url and app.logo_url.startswith(storage_base_url):
|
||||
# Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
|
||||
old_path = app.logo_url.replace(storage_base_url, "")
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
await async_client.delete(bucket_name, old_path)
|
||||
logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
|
||||
except Exception as e:
|
||||
# Log but don't fail - the new logo was uploaded successfully
|
||||
logger.warning(
|
||||
f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
|
||||
)
|
||||
1784
autogpt_platform/backend/backend/api/features/oauth_test.py
Normal file
1784
autogpt_platform/backend/backend/api/features/oauth_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,9 +6,9 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.otto.models as otto_models
|
||||
import backend.server.v2.otto.routes as otto_routes
|
||||
from backend.server.v2.otto.service import OttoService
|
||||
from . import models as otto_models
|
||||
from . import routes as otto_routes
|
||||
from .service import OttoService
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(otto_routes.router)
|
||||
@@ -4,12 +4,15 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Security
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from backend.api.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.data.user import (
|
||||
get_user_by_email,
|
||||
set_user_email_verification,
|
||||
unsubscribe_user_by_token,
|
||||
)
|
||||
from backend.server.routers.postmark.models import (
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import (
|
||||
PostmarkBounceEnum,
|
||||
PostmarkBounceWebhook,
|
||||
PostmarkClickWebhook,
|
||||
@@ -19,8 +22,6 @@ from backend.server.routers.postmark.models import (
|
||||
PostmarkSubscriptionChangeWebhook,
|
||||
PostmarkWebhook,
|
||||
)
|
||||
from backend.server.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Literal
|
||||
|
||||
import backend.server.v2.store.db
|
||||
from backend.util.cache import cached
|
||||
|
||||
from . import db as store_db
|
||||
|
||||
##############################################
|
||||
############### Caches #######################
|
||||
##############################################
|
||||
@@ -29,7 +30,7 @@ async def _get_cached_store_agents(
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store agents."""
|
||||
return await backend.server.v2.store.db.get_store_agents(
|
||||
return await store_db.get_store_agents(
|
||||
featured=featured,
|
||||
creators=[creator] if creator else None,
|
||||
sorted_by=sorted_by,
|
||||
@@ -42,10 +43,12 @@ async def _get_cached_store_agents(
|
||||
|
||||
# Cache individual agent details for 15 minutes
|
||||
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
async def _get_cached_agent_details(
|
||||
username: str, agent_name: str, include_changelog: bool = False
|
||||
):
|
||||
"""Cached helper to get agent details."""
|
||||
return await backend.server.v2.store.db.get_store_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
return await store_db.get_store_agent_details(
|
||||
username=username, agent_name=agent_name, include_changelog=include_changelog
|
||||
)
|
||||
|
||||
|
||||
@@ -59,7 +62,7 @@ async def _get_cached_store_creators(
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store creators."""
|
||||
return await backend.server.v2.store.db.get_store_creators(
|
||||
return await store_db.get_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
@@ -72,6 +75,4 @@ async def _get_cached_store_creators(
|
||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await backend.server.v2.store.db.get_store_creator_details(
|
||||
username=username.lower()
|
||||
)
|
||||
return await store_db.get_store_creator_details(username=username.lower())
|
||||
@@ -0,0 +1,610 @@
|
||||
"""
|
||||
Content Type Handlers for Unified Embeddings
|
||||
|
||||
Pluggable system for different content sources (store agents, blocks, docs).
|
||||
Each handler knows how to fetch and process its content type for embedding.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentItem:
|
||||
"""Represents a piece of content to be embedded."""
|
||||
|
||||
content_id: str # Unique identifier (DB ID or file path)
|
||||
content_type: ContentType
|
||||
searchable_text: str # Combined text for embedding
|
||||
metadata: dict[str, Any] # Content-specific metadata
|
||||
user_id: str | None = None # For user-scoped content
|
||||
|
||||
|
||||
class ContentHandler(ABC):
|
||||
"""Base handler for fetching and processing content for embeddings."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def content_type(self) -> ContentType:
|
||||
"""The ContentType this handler manages."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""
|
||||
Fetch items that don't have embeddings yet.
|
||||
|
||||
Args:
|
||||
batch_size: Maximum number of items to return
|
||||
|
||||
Returns:
|
||||
List of ContentItem objects ready for embedding
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get statistics about embedding coverage.
|
||||
|
||||
Returns:
|
||||
Dict with keys: total, with_embeddings, without_embeddings
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StoreAgentHandler(ContentHandler):
|
||||
"""Handler for marketplace store agent listings."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.STORE_AGENT
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch approved store listings without embeddings."""
|
||||
from backend.api.features.store.embeddings import build_searchable_text
|
||||
|
||||
missing = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
slv.id,
|
||||
slv.name,
|
||||
slv.description,
|
||||
slv."subHeading",
|
||||
slv.categories
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
|
||||
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND uce."contentId" IS NULL
|
||||
LIMIT $1
|
||||
""",
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return [
|
||||
ContentItem(
|
||||
content_id=row["id"],
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
searchable_text=build_searchable_text(
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
sub_heading=row["subHeading"],
|
||||
categories=row["categories"] or [],
|
||||
),
|
||||
metadata={
|
||||
"name": row["name"],
|
||||
"categories": row["categories"] or [],
|
||||
},
|
||||
user_id=None, # Store agents are public
|
||||
)
|
||||
for row in missing
|
||||
]
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about store agent embedding coverage."""
|
||||
# Count approved versions
|
||||
approved_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
AND "isDeleted" = false
|
||||
"""
|
||||
)
|
||||
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||
|
||||
# Count versions with embeddings
|
||||
embedded_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
"""
|
||||
)
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_approved,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_approved - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
class BlockHandler(ContentHandler):
|
||||
"""Handler for block definitions (Python classes)."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.BLOCK
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch blocks without embeddings."""
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# Get all available blocks
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Check which ones have embeddings
|
||||
if not all_blocks:
|
||||
return []
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
|
||||
# Query for existing embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*block_ids,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_blocks = [
|
||||
(block_id, block_cls)
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if block_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem
|
||||
items = []
|
||||
for block_id, block_cls in missing_blocks[:batch_size]:
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
# Build searchable text from block metadata
|
||||
parts = []
|
||||
if hasattr(block_instance, "name") and block_instance.name:
|
||||
parts.append(block_instance.name)
|
||||
if (
|
||||
hasattr(block_instance, "description")
|
||||
and block_instance.description
|
||||
):
|
||||
parts.append(block_instance.description)
|
||||
if hasattr(block_instance, "categories") and block_instance.categories:
|
||||
# Convert BlockCategory enum to strings
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
)
|
||||
|
||||
# Add input/output schema info
|
||||
if hasattr(block_instance, "input_schema"):
|
||||
schema = block_instance.input_schema
|
||||
if hasattr(schema, "model_json_schema"):
|
||||
schema_dict = schema.model_json_schema()
|
||||
if "properties" in schema_dict:
|
||||
for prop_name, prop_info in schema_dict[
|
||||
"properties"
|
||||
].items():
|
||||
if "description" in prop_info:
|
||||
parts.append(
|
||||
f"{prop_name}: {prop_info['description']}"
|
||||
)
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
# Convert categories set of enums to list of strings for JSON serialization
|
||||
categories = getattr(block_instance, "categories", set())
|
||||
categories_list = (
|
||||
[cat.value for cat in categories] if categories else []
|
||||
)
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=block_id,
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": getattr(block_instance, "name", ""),
|
||||
"categories": categories_list,
|
||||
},
|
||||
user_id=None, # Blocks are public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process block {block_id}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about block embedding coverage."""
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
all_blocks = get_blocks()
|
||||
total_blocks = len(all_blocks)
|
||||
|
||||
if total_blocks == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*block_ids,
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_blocks,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_blocks - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarkdownSection:
|
||||
"""Represents a section of a markdown document."""
|
||||
|
||||
title: str # Section heading text
|
||||
content: str # Section content (including the heading line)
|
||||
level: int # Heading level (1 for #, 2 for ##, etc.)
|
||||
index: int # Section index within the document
|
||||
|
||||
|
||||
class DocumentationHandler(ContentHandler):
|
||||
"""Handler for documentation files (.md/.mdx).
|
||||
|
||||
Chunks documents by markdown headings to create multiple embeddings per file.
|
||||
Each section (## heading) becomes a separate embedding for better retrieval.
|
||||
"""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.DOCUMENTATION
|
||||
|
||||
def _get_docs_root(self) -> Path:
|
||||
"""Get the documentation root directory."""
|
||||
# content_handlers.py is at: backend/backend/api/features/store/content_handlers.py
|
||||
# Need to go up to project root then into docs/
|
||||
# In container: /app/autogpt_platform/backend/backend/api/features/store -> /app/docs
|
||||
# In development: /repo/autogpt_platform/backend/backend/api/features/store -> /repo/docs
|
||||
this_file = Path(
|
||||
__file__
|
||||
) # .../backend/backend/api/features/store/content_handlers.py
|
||||
project_root = (
|
||||
this_file.parent.parent.parent.parent.parent.parent.parent
|
||||
) # -> /app or /repo
|
||||
docs_root = project_root / "docs"
|
||||
return docs_root
|
||||
|
||||
def _extract_doc_title(self, file_path: Path) -> str:
|
||||
"""Extract the document title from a markdown file."""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
lines = content.split("\n")
|
||||
|
||||
# Try to extract title from first # heading
|
||||
for line in lines:
|
||||
if line.startswith("# "):
|
||||
return line[2:].strip()
|
||||
|
||||
# If no title found, use filename
|
||||
return file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read title from {file_path}: {e}")
|
||||
return file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
|
||||
def _chunk_markdown_by_headings(
|
||||
self, file_path: Path, min_heading_level: int = 2
|
||||
) -> list[MarkdownSection]:
|
||||
"""
|
||||
Split a markdown file into sections based on headings.
|
||||
|
||||
Args:
|
||||
file_path: Path to the markdown file
|
||||
min_heading_level: Minimum heading level to split on (default: 2 for ##)
|
||||
|
||||
Returns:
|
||||
List of MarkdownSection objects, one per section.
|
||||
If no headings found, returns a single section with all content.
|
||||
"""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read {file_path}: {e}")
|
||||
return []
|
||||
|
||||
lines = content.split("\n")
|
||||
sections: list[MarkdownSection] = []
|
||||
current_section_lines: list[str] = []
|
||||
current_title = ""
|
||||
current_level = 0
|
||||
section_index = 0
|
||||
doc_title = ""
|
||||
|
||||
for line in lines:
|
||||
# Check if line is a heading
|
||||
if line.startswith("#"):
|
||||
# Count heading level
|
||||
level = 0
|
||||
for char in line:
|
||||
if char == "#":
|
||||
level += 1
|
||||
else:
|
||||
break
|
||||
|
||||
heading_text = line[level:].strip()
|
||||
|
||||
# Track document title (level 1 heading)
|
||||
if level == 1 and not doc_title:
|
||||
doc_title = heading_text
|
||||
# Don't create a section for just the title - add it to first section
|
||||
current_section_lines.append(line)
|
||||
continue
|
||||
|
||||
# Check if this heading should start a new section
|
||||
if level >= min_heading_level:
|
||||
# Save previous section if it has content
|
||||
if current_section_lines:
|
||||
section_content = "\n".join(current_section_lines).strip()
|
||||
if section_content:
|
||||
# Use doc title for first section if no specific title
|
||||
title = current_title if current_title else doc_title
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace(
|
||||
"_", " "
|
||||
)
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=section_content,
|
||||
level=current_level if current_level else 1,
|
||||
index=section_index,
|
||||
)
|
||||
)
|
||||
section_index += 1
|
||||
|
||||
# Start new section
|
||||
current_section_lines = [line]
|
||||
current_title = heading_text
|
||||
current_level = level
|
||||
else:
|
||||
# Lower level heading (e.g., # when splitting on ##)
|
||||
current_section_lines.append(line)
|
||||
else:
|
||||
current_section_lines.append(line)
|
||||
|
||||
# Don't forget the last section
|
||||
if current_section_lines:
|
||||
section_content = "\n".join(current_section_lines).strip()
|
||||
if section_content:
|
||||
title = current_title if current_title else doc_title
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace("_", " ")
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=section_content,
|
||||
level=current_level if current_level else 1,
|
||||
index=section_index,
|
||||
)
|
||||
)
|
||||
|
||||
# If no sections were created (no headings found), create one section with all content
|
||||
if not sections and content.strip():
|
||||
title = (
|
||||
doc_title
|
||||
if doc_title
|
||||
else file_path.stem.replace("-", " ").replace("_", " ")
|
||||
)
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=content.strip(),
|
||||
level=1,
|
||||
index=0,
|
||||
)
|
||||
)
|
||||
|
||||
return sections
|
||||
|
||||
def _make_section_content_id(self, doc_path: str, section_index: int) -> str:
|
||||
"""Create a unique content ID for a document section.
|
||||
|
||||
Format: doc_path::section_index
|
||||
Example: 'platform/getting-started.md::0'
|
||||
"""
|
||||
return f"{doc_path}::{section_index}"
|
||||
|
||||
def _parse_section_content_id(self, content_id: str) -> tuple[str, int]:
|
||||
"""Parse a section content ID back into doc_path and section_index.
|
||||
|
||||
Returns: (doc_path, section_index)
|
||||
"""
|
||||
if "::" in content_id:
|
||||
parts = content_id.rsplit("::", 1)
|
||||
return parts[0], int(parts[1])
|
||||
# Legacy format (whole document)
|
||||
return content_id, 0
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch documentation sections without embeddings.
|
||||
|
||||
Chunks each document by markdown headings and creates embeddings for each section.
|
||||
Content IDs use the format: 'path/to/doc.md::section_index'
|
||||
"""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
logger.warning(f"Documentation root not found: {docs_root}")
|
||||
return []
|
||||
|
||||
# Find all .md and .mdx files
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
|
||||
if not all_docs:
|
||||
return []
|
||||
|
||||
# Build list of all sections from all documents
|
||||
all_sections: list[tuple[str, Path, MarkdownSection]] = []
|
||||
for doc_file in all_docs:
|
||||
doc_path = str(doc_file.relative_to(docs_root))
|
||||
sections = self._chunk_markdown_by_headings(doc_file)
|
||||
for section in sections:
|
||||
all_sections.append((doc_path, doc_file, section))
|
||||
|
||||
if not all_sections:
|
||||
return []
|
||||
|
||||
# Generate content IDs for all sections
|
||||
section_content_ids = [
|
||||
self._make_section_content_id(doc_path, section.index)
|
||||
for doc_path, _, section in all_sections
|
||||
]
|
||||
|
||||
# Check which ones have embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(section_content_ids))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*section_content_ids,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
|
||||
# Filter to missing sections
|
||||
missing_sections = [
|
||||
(doc_path, doc_file, section, content_id)
|
||||
for (doc_path, doc_file, section), content_id in zip(
|
||||
all_sections, section_content_ids
|
||||
)
|
||||
if content_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem (up to batch_size)
|
||||
items = []
|
||||
for doc_path, doc_file, section, content_id in missing_sections[:batch_size]:
|
||||
try:
|
||||
# Get document title for context
|
||||
doc_title = self._extract_doc_title(doc_file)
|
||||
|
||||
# Build searchable text with context
|
||||
# Include doc title and section title for better search relevance
|
||||
searchable_text = f"{doc_title} - {section.title}\n\n{section.content}"
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=content_id,
|
||||
content_type=ContentType.DOCUMENTATION,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"doc_title": doc_title,
|
||||
"section_title": section.title,
|
||||
"section_index": section.index,
|
||||
"heading_level": section.level,
|
||||
"path": doc_path,
|
||||
},
|
||||
user_id=None, # Documentation is public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process section {content_id}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
def _get_all_section_content_ids(self, docs_root: Path) -> set[str]:
|
||||
"""Get all current section content IDs from the docs directory.
|
||||
|
||||
Used for stats and cleanup to know what sections should exist.
|
||||
"""
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
content_ids = set()
|
||||
|
||||
for doc_file in all_docs:
|
||||
doc_path = str(doc_file.relative_to(docs_root))
|
||||
sections = self._chunk_markdown_by_headings(doc_file)
|
||||
for section in sections:
|
||||
content_ids.add(self._make_section_content_id(doc_path, section.index))
|
||||
|
||||
return content_ids
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about documentation embedding coverage.
|
||||
|
||||
Counts sections (not documents) since each section gets its own embedding.
|
||||
"""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
# Get all section content IDs
|
||||
all_section_ids = self._get_all_section_content_ids(docs_root)
|
||||
total_sections = len(all_section_ids)
|
||||
|
||||
if total_sections == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
# Count embeddings in database for DOCUMENTATION type
|
||||
embedded_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{schema_prefix}"ContentType"
|
||||
"""
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_sections,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_sections - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
# Content handler registry
|
||||
CONTENT_HANDLERS: dict[ContentType, ContentHandler] = {
|
||||
ContentType.STORE_AGENT: StoreAgentHandler(),
|
||||
ContentType.BLOCK: BlockHandler(),
|
||||
ContentType.DOCUMENTATION: DocumentationHandler(),
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Integration tests for content handlers using real DB.
|
||||
|
||||
Run with: poetry run pytest backend/api/features/store/content_handlers_integration_test.py -xvs
|
||||
|
||||
These tests use the real database but mock OpenAI calls.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
from backend.api.features.store.embeddings import (
|
||||
EMBEDDING_DIM,
|
||||
backfill_all_content_types,
|
||||
ensure_content_embedding,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_real_db():
|
||||
"""Test StoreAgentHandler with real database queries."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Get stats from real DB
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list (may be empty if all have embeddings)
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None
|
||||
assert item.content_type.value == "STORE_AGENT"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_real_db():
|
||||
"""Test BlockHandler with real database queries."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Get stats from real DB
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0 # Should have at least some blocks
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None # Should be block UUID
|
||||
assert item.content_type.value == "BLOCK"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_real_fs():
|
||||
"""Test DocumentationHandler with real filesystem."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Get stats from real filesystem
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None # Should be relative path
|
||||
assert item.content_type.value == "DOCUMENTATION"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_stats_all_types():
|
||||
"""Test get_embedding_stats aggregates all content types."""
|
||||
stats = await get_embedding_stats()
|
||||
|
||||
# Should have structure with by_type and totals
|
||||
assert "by_type" in stats
|
||||
assert "totals" in stats
|
||||
|
||||
# Check each content type is present
|
||||
by_type = stats["by_type"]
|
||||
assert "STORE_AGENT" in by_type
|
||||
assert "BLOCK" in by_type
|
||||
assert "DOCUMENTATION" in by_type
|
||||
|
||||
# Check totals are aggregated
|
||||
totals = stats["totals"]
|
||||
assert totals["total"] >= 0
|
||||
assert totals["with_embeddings"] >= 0
|
||||
assert totals["without_embeddings"] >= 0
|
||||
assert "coverage_percent" in totals
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
async def test_ensure_content_embedding_blocks(mock_generate):
|
||||
"""Test creating embeddings for blocks (mocked OpenAI)."""
|
||||
# Mock OpenAI to return fake embedding
|
||||
mock_generate.return_value = [0.1] * EMBEDDING_DIM
|
||||
|
||||
# Get one block without embedding
|
||||
handler = BlockHandler()
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
if not items:
|
||||
pytest.skip("No blocks without embeddings")
|
||||
|
||||
item = items[0]
|
||||
|
||||
# Try to create embedding (OpenAI mocked)
|
||||
result = await ensure_content_embedding(
|
||||
content_type=item.content_type,
|
||||
content_id=item.content_id,
|
||||
searchable_text=item.searchable_text,
|
||||
metadata=item.metadata,
|
||||
user_id=item.user_id,
|
||||
)
|
||||
|
||||
# Should succeed with mocked OpenAI
|
||||
assert result is True
|
||||
mock_generate.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
async def test_backfill_all_content_types_dry_run(mock_generate):
|
||||
"""Test backfill_all_content_types processes all handlers in order."""
|
||||
# Mock OpenAI to return fake embedding
|
||||
mock_generate.return_value = [0.1] * EMBEDDING_DIM
|
||||
|
||||
# Run backfill with batch_size=1 to process max 1 per type
|
||||
result = await backfill_all_content_types(batch_size=1)
|
||||
|
||||
# Should have results for all content types
|
||||
assert "by_type" in result
|
||||
assert "totals" in result
|
||||
|
||||
by_type = result["by_type"]
|
||||
assert "BLOCK" in by_type
|
||||
assert "STORE_AGENT" in by_type
|
||||
assert "DOCUMENTATION" in by_type
|
||||
|
||||
# Each type should have correct structure
|
||||
for content_type, type_result in by_type.items():
|
||||
assert "processed" in type_result
|
||||
assert "success" in type_result
|
||||
assert "failed" in type_result
|
||||
|
||||
# Totals should aggregate
|
||||
totals = result["totals"]
|
||||
assert totals["processed"] >= 0
|
||||
assert totals["success"] >= 0
|
||||
assert totals["failed"] >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handler_registry():
|
||||
"""Test all handlers are registered in correct order."""
|
||||
from prisma.enums import ContentType
|
||||
|
||||
# All three types should be registered
|
||||
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
|
||||
assert ContentType.BLOCK in CONTENT_HANDLERS
|
||||
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
|
||||
|
||||
# Check handler types
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
E2E tests for content handlers (blocks, store agents, documentation).
|
||||
|
||||
Tests the full flow: discovering content → generating embeddings → storing.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_missing_items(mocker):
|
||||
"""Test StoreAgentHandler fetches approved agents without embeddings."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock database query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "agent-1",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"subHeading": "Test heading",
|
||||
"categories": ["AI", "Testing"],
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_missing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "agent-1"
|
||||
assert items[0].content_type == ContentType.STORE_AGENT
|
||||
assert "Test Agent" in items[0].searchable_text
|
||||
assert "A test agent" in items[0].searchable_text
|
||||
assert items[0].metadata["name"] == "Test Agent"
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_stats(mocker):
|
||||
"""Test StoreAgentHandler returns correct stats."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock approved count query
|
||||
mock_approved = [{"count": 50}]
|
||||
# Mock embedded count query
|
||||
mock_embedded = [{"count": 30}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
side_effect=[mock_approved, mock_embedded],
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 50
|
||||
assert stats["with_embeddings"] == 30
|
||||
assert stats["without_embeddings"] == 20
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items(mocker):
|
||||
"""Test BlockHandler discovers blocks without embeddings."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks to return test blocks
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Calculator Block"
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.input_schema.model_json_schema.return_value = {
|
||||
"properties": {"expression": {"description": "Math expression to evaluate"}}
|
||||
}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
mock_existing = []
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_existing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "block-uuid-1"
|
||||
assert items[0].content_type == ContentType.BLOCK
|
||||
assert "Calculator Block" in items[0].searchable_text
|
||||
assert "Performs calculations" in items[0].searchable_text
|
||||
assert "MATH" in items[0].searchable_text
|
||||
assert "expression: Math expression" in items[0].searchable_text
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats(mocker):
|
||||
"""Test BlockHandler returns correct stats."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks
|
||||
mock_blocks = {
|
||||
"block-1": MagicMock(),
|
||||
"block-2": MagicMock(),
|
||||
"block-3": MagicMock(),
|
||||
}
|
||||
|
||||
# Mock embedded count query (2 blocks have embeddings)
|
||||
mock_embedded = [{"count": 2}]
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 3
|
||||
assert stats["with_embeddings"] == 2
|
||||
assert stats["without_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
"""Test DocumentationHandler discovers docs without embeddings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory with test files
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
|
||||
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
|
||||
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
|
||||
|
||||
# Mock _get_docs_root to return temp dir
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
# Check guide.md (content_id format: doc_path::section_index)
|
||||
guide_item = next(
|
||||
(item for item in items if item.content_id == "guide.md::0"), None
|
||||
)
|
||||
assert guide_item is not None
|
||||
assert guide_item.content_type == ContentType.DOCUMENTATION
|
||||
assert "Getting Started" in guide_item.searchable_text
|
||||
assert "This is a guide" in guide_item.searchable_text
|
||||
assert guide_item.metadata["doc_title"] == "Getting Started"
|
||||
assert guide_item.user_id is None
|
||||
|
||||
# Check api.mdx (content_id format: doc_path::section_index)
|
||||
api_item = next(
|
||||
(item for item in items if item.content_id == "api.mdx::0"), None
|
||||
)
|
||||
assert api_item is not None
|
||||
assert "API Reference" in api_item.searchable_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_stats(tmp_path, mocker):
|
||||
"""Test DocumentationHandler returns correct stats."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
(docs_root / "doc1.md").write_text("# Doc 1")
|
||||
(docs_root / "doc2.md").write_text("# Doc 2")
|
||||
(docs_root / "doc3.mdx").write_text("# Doc 3")
|
||||
|
||||
# Mock embedded count query (1 doc has embedding)
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 3
|
||||
assert stats["with_embeddings"] == 1
|
||||
assert stats["without_embeddings"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_title_extraction(tmp_path):
|
||||
"""Test DocumentationHandler extracts title from markdown heading."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test with heading
|
||||
doc_with_heading = tmp_path / "with_heading.md"
|
||||
doc_with_heading.write_text("# My Title\n\nContent here")
|
||||
title = handler._extract_doc_title(doc_with_heading)
|
||||
assert title == "My Title"
|
||||
|
||||
# Test without heading
|
||||
doc_without_heading = tmp_path / "no-heading.md"
|
||||
doc_without_heading.write_text("Just content, no heading")
|
||||
title = handler._extract_doc_title(doc_without_heading)
|
||||
assert title == "No Heading" # Uses filename
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
"""Test DocumentationHandler chunks markdown by headings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test document with multiple sections
|
||||
doc_with_sections = tmp_path / "sections.md"
|
||||
doc_with_sections.write_text(
|
||||
"# Document Title\n\n"
|
||||
"Intro paragraph.\n\n"
|
||||
"## Section One\n\n"
|
||||
"Content for section one.\n\n"
|
||||
"## Section Two\n\n"
|
||||
"Content for section two.\n"
|
||||
)
|
||||
sections = handler._chunk_markdown_by_headings(doc_with_sections)
|
||||
|
||||
# Should have 3 sections: intro (with doc title), section one, section two
|
||||
assert len(sections) == 3
|
||||
assert sections[0].title == "Document Title"
|
||||
assert sections[0].index == 0
|
||||
assert "Intro paragraph" in sections[0].content
|
||||
|
||||
assert sections[1].title == "Section One"
|
||||
assert sections[1].index == 1
|
||||
assert "Content for section one" in sections[1].content
|
||||
|
||||
assert sections[2].title == "Section Two"
|
||||
assert sections[2].index == 2
|
||||
assert "Content for section two" in sections[2].content
|
||||
|
||||
# Test document without headings
|
||||
doc_no_sections = tmp_path / "no-sections.md"
|
||||
doc_no_sections.write_text("Just plain content without any headings.")
|
||||
sections = handler._chunk_markdown_by_headings(doc_no_sections)
|
||||
assert len(sections) == 1
|
||||
assert sections[0].index == 0
|
||||
assert "Just plain content" in sections[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_section_content_ids():
|
||||
"""Test DocumentationHandler creates and parses section content IDs."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test making content ID
|
||||
content_id = handler._make_section_content_id("docs/guide.md", 2)
|
||||
assert content_id == "docs/guide.md::2"
|
||||
|
||||
# Test parsing content ID
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
|
||||
assert doc_path == "docs/guide.md"
|
||||
assert section_index == 2
|
||||
|
||||
# Test parsing legacy format (no section index)
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
|
||||
assert doc_path == "docs/old-format.md"
|
||||
assert section_index == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handlers_registry():
|
||||
"""Test all content types are registered."""
|
||||
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
|
||||
assert ContentType.BLOCK in CONTENT_HANDLERS
|
||||
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
|
||||
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_missing_attributes():
|
||||
"""Test BlockHandler gracefully handles blocks with missing attributes."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with minimal attributes
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
# No description, categories, or schema
|
||||
del mock_block_instance.description
|
||||
del mock_block_instance.categories
|
||||
del mock_block_instance.input_schema
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock one good block and one bad block
|
||||
good_block = MagicMock()
|
||||
good_instance = MagicMock()
|
||||
good_instance.name = "Good Block"
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
bad_block.side_effect = Exception("Instantiation failed")
|
||||
|
||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
# Should only get the good block
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Mock _get_docs_root to return non-existent path
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user