Compare commits

..

7 Commits

Author SHA1 Message Date
Nicholas Tindle
ad1a814724 Update manager.py 2026-02-09 13:42:14 -06:00
Otto
562cf04ab6 refactor(backend): extract shared auto-credential parsing to utils.py
Addresses review feedback on #12004:
- Added AutoCredentialFieldInfo dataclass and parse_auto_credential_field()
  helper to executor/utils.py
- Updated _acquire_auto_credentials in manager.py to use shared helper
- Updated _validate_node_input_credentials in utils.py to use shared helper

This consolidates the duplicate logic for parsing GoogleDriveFileField-style
auto-credential fields, making manager.py less cluttered while ensuring
consistent validation/acquisition behavior.
2026-02-09 07:57:36 +00:00
Nicholas Tindle
90b3b5ba16 fix(backend): Fix misplaced section header in graph_test.py
Move the _reassign_ids section comment to above the actual _reassign_ids
tests, and label the combine() tests correctly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 16:11:47 -06:00
Nicholas Tindle
f4f81bc4fc fix(backend): Remove _credentials_id key on fork instead of setting to None
Setting _credentials_id to None on fork was ambiguous — both "forked,
needs re-auth" and "chained data from upstream" were represented as None.
This caused _acquire_auto_credentials to silently skip credential
acquisition for forked agents, leading to confusing TypeErrors at runtime.

Now the key is deleted entirely, making the three states unambiguous:
- Present with value: user-selected credentials
- Present as None: chained data from upstream block
- Absent: forked/needs re-authentication

Also adds pre-run validation for the missing key case and makes error
messages provider-agnostic.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 17:34:16 -06:00
Nicholas Tindle
c5abc01f25 fix(backend): Add error handling for auto-credentials store lookup
Wrap get_creds_by_id call in try/except in the auto-credentials
validation path to match the error handling pattern used for regular
credentials.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 16:53:29 -06:00
Nicholas Tindle
8b7053c1de merge: Resolve conflicts with dev (PR #11986 graph model refactor)
Adapt auto-credentials filtering to dev's refactored graph model:
- aggregate_credentials_inputs() now returns 3-tuples (field_info, node_pairs, is_required)
- credentials_input_schema moved to GraphModel, builds JSON schema directly
- Update regular/auto_credentials_inputs properties for 3-tuple format
- Update test mocks and assertions for new tuple format and class hierarchy

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 16:39:57 -06:00
Nicholas Tindle
e00c1202ad fix(platform): Fix Google Drive auto-credentials handling across the platform
- Tag auto-credentials with `is_auto_credential` and `input_field_name` on `CredentialsFieldInfo` to distinguish them from regular user-provided credentials
- Add `regular_credentials_inputs` and `auto_credentials_inputs` properties to `Graph` so UI schemas, CoPilot, and library presets only surface regular credentials
- Extract `_acquire_auto_credentials()` helper in executor to resolve embedded `_credentials_id` at execution time with proper lock management
- Validate auto-credentials ownership in `_validate_node_input_credentials()` to catch stale/missing credentials before execution
- Clear `_credentials_id` in `_reassign_ids()` on graph fork so cloned agents require re-authentication
- Propagate `is_auto_credential` through `combine()` and `discriminate()` on `CredentialsFieldInfo`
- Add `referrerPolicy: "no-referrer-when-downgrade"` to Google API script loading to fix Firefox API key validation
- Comprehensive test coverage for all new behavior

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 16:08:53 -06:00
574 changed files with 19614 additions and 41703 deletions

View File

@@ -5,13 +5,42 @@
!docs/ !docs/
# Platform - Libs # Platform - Libs
!autogpt_platform/autogpt_libs/ !autogpt_platform/autogpt_libs/autogpt_libs/
!autogpt_platform/autogpt_libs/pyproject.toml
!autogpt_platform/autogpt_libs/poetry.lock
!autogpt_platform/autogpt_libs/README.md
# Platform - Backend # Platform - Backend
!autogpt_platform/backend/ !autogpt_platform/backend/backend/
!autogpt_platform/backend/test/e2e_test_data.py
!autogpt_platform/backend/migrations/
!autogpt_platform/backend/schema.prisma
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py
# Platform - Market
!autogpt_platform/market/market/
!autogpt_platform/market/scripts.py
!autogpt_platform/market/schema.prisma
!autogpt_platform/market/pyproject.toml
!autogpt_platform/market/poetry.lock
!autogpt_platform/market/README.md
# Platform - Frontend # Platform - Frontend
!autogpt_platform/frontend/ !autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json
!autogpt_platform/frontend/README.md
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env
# Classic - AutoGPT # Classic - AutoGPT
!classic/original_autogpt/autogpt/ !classic/original_autogpt/autogpt/
@@ -35,38 +64,6 @@
# Classic - Frontend # Classic - Frontend
!classic/frontend/build/web/ !classic/frontend/build/web/
# Explicitly re-ignore unwanted files from whitelisted directories # Explicitly re-ignore some folders
# Note: These patterns MUST come after the whitelist rules to take effect .*
**/__pycache__
# Hidden files and directories (but keep frontend .env files needed for build)
**/.*
!autogpt_platform/frontend/.env
!autogpt_platform/frontend/.env.default
!autogpt_platform/frontend/.env.production
# Python artifacts
**/__pycache__/
**/*.pyc
**/*.pyo
**/.venv/
**/.ruff_cache/
**/.pytest_cache/
**/.coverage
**/htmlcov/
# Node artifacts
**/node_modules/
**/.next/
**/storybook-static/
**/playwright-report/
**/test-results/
# Build artifacts
**/dist/
**/build/
!autogpt_platform/frontend/src/**/build/
**/target/
# Logs and temp files
**/*.log
**/*.tmp

File diff suppressed because it is too large Load Diff

View File

@@ -49,7 +49,7 @@ jobs:
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }} - name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push' if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v8 uses: peter-evans/create-pull-request@v7
with: with:
add-paths: classic/frontend/build/web add-paths: classic/frontend/build/web
base: ${{ github.ref_name }} base: ${{ github.ref_name }}

View File

@@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
ref: ${{ github.event.workflow_run.head_branch }} ref: ${{ github.event.workflow_run.head_branch }}
fetch-depth: 0 fetch-depth: 0
@@ -40,51 +40,9 @@ jobs:
git checkout -b "$BRANCH_NAME" git checkout -b "$BRANCH_NAME"
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
# Backend Python/Poetry setup (so Claude can run linting/tests)
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
run: |
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Install Python dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (so Claude can run linting/tests)
- name: Enable corepack
run: corepack enable
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
run: pnpm install --frozen-lockfile
- name: Get CI failure details - name: Get CI failure details
id: failure_details id: failure_details
uses: actions/github-script@v8 uses: actions/github-script@v7
with: with:
script: | script: |
const run = await github.rest.actions.getWorkflowRun({ const run = await github.rest.actions.getWorkflowRun({

View File

@@ -30,7 +30,7 @@ jobs:
actions: read # Required for CI access actions: read # Required for CI access
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 1 fetch-depth: 1
@@ -41,7 +41,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache - name: Set up Python dependency cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/.cache/pypoetry path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -77,15 +77,27 @@ jobs:
run: poetry run prisma generate && poetry run gen-prisma-stub run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Enable corepack - name: Enable corepack
run: corepack enable run: corepack enable
- name: Set up Node.js - name: Set pnpm store directory
uses: actions/setup-node@v6 run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v4
with: with:
node-version: "22" path: ~/.pnpm-store
cache: "pnpm" key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies - name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend working-directory: autogpt_platform/frontend
@@ -112,7 +124,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup # Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache - name: Set up Docker image cache
id: docker-cache id: docker-cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/docker-cache path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes # Use a versioned key for cache invalidation when image list changes
@@ -297,7 +309,6 @@ jobs:
uses: anthropics/claude-code-action@v1 uses: anthropics/claude-code-action@v1
with: with:
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
allowed_bots: "dependabot[bot]"
claude_args: | claude_args: |
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)" --allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
prompt: | prompt: |

View File

@@ -40,7 +40,7 @@ jobs:
actions: read # Required for CI access actions: read # Required for CI access
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 1 fetch-depth: 1
@@ -57,7 +57,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache - name: Set up Python dependency cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/.cache/pypoetry path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -93,15 +93,27 @@ jobs:
run: poetry run prisma generate && poetry run gen-prisma-stub run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Enable corepack - name: Enable corepack
run: corepack enable run: corepack enable
- name: Set up Node.js - name: Set pnpm store directory
uses: actions/setup-node@v6 run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v4
with: with:
node-version: "22" path: ~/.pnpm-store
cache: "pnpm" key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies - name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend working-directory: autogpt_platform/frontend
@@ -128,7 +140,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup # Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache - name: Set up Docker image cache
id: docker-cache id: docker-cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/docker-cache path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes # Use a versioned key for cache invalidation when image list changes

View File

@@ -58,11 +58,11 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
# Initializes the CodeQL tools for scanning. # Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL - name: Initialize CodeQL
uses: github/codeql-action/init@v4 uses: github/codeql-action/init@v3
with: with:
languages: ${{ matrix.language }} languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }} build-mode: ${{ matrix.build-mode }}
@@ -93,6 +93,6 @@ jobs:
exit 1 exit 1
- name: Perform CodeQL Analysis - name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4 uses: github/codeql-action/analyze@v3
with: with:
category: "/language:${{matrix.language}}" category: "/language:${{matrix.language}}"

View File

@@ -27,7 +27,7 @@ jobs:
# If you do not check out your code, Copilot will do this for you. # If you do not check out your code, Copilot will do this for you.
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
submodules: true submodules: true
@@ -39,7 +39,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache - name: Set up Python dependency cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/.cache/pypoetry path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -76,7 +76,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v6 uses: actions/setup-node@v4
with: with:
node-version: "22" node-version: "22"
@@ -89,7 +89,7 @@ jobs:
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies - name: Cache frontend dependencies
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/.pnpm-store path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }} key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -132,7 +132,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup # Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache - name: Set up Docker image cache
id: docker-cache id: docker-cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/docker-cache path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes # Use a versioned key for cache invalidation when image list changes

View File

@@ -23,7 +23,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 1 fetch-depth: 1
@@ -33,7 +33,7 @@ jobs:
python-version: "3.11" python-version: "3.11"
- name: Set up Python dependency cache - name: Set up Python dependency cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/.cache/pypoetry path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -7,10 +7,6 @@ on:
- "docs/integrations/**" - "docs/integrations/**"
- "autogpt_platform/backend/backend/blocks/**" - "autogpt_platform/backend/backend/blocks/**"
concurrency:
group: claude-docs-review-${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs: jobs:
claude-review: claude-review:
# Only run for PRs from members/collaborators # Only run for PRs from members/collaborators
@@ -27,7 +23,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
@@ -37,7 +33,7 @@ jobs:
python-version: "3.11" python-version: "3.11"
- name: Set up Python dependency cache - name: Set up Python dependency cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/.cache/pypoetry path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -95,35 +91,5 @@ jobs:
3. Read corresponding documentation files to verify accuracy 3. Read corresponding documentation files to verify accuracy
4. Provide your feedback as a PR comment 4. Provide your feedback as a PR comment
## IMPORTANT: Comment Marker
Start your PR comment with exactly this HTML comment marker on its own line:
<!-- CLAUDE_DOCS_REVIEW -->
This marker is used to identify and replace your comment on subsequent runs.
Be constructive and specific. If everything looks good, say so! Be constructive and specific. If everything looks good, say so!
If there are issues, explain what's wrong and suggest how to fix it. If there are issues, explain what's wrong and suggest how to fix it.
- name: Delete old Claude review comments
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Get all comment IDs with our marker, sorted by creation date (oldest first)
COMMENT_IDS=$(gh api \
repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
--jq '[.[] | select(.body | contains("<!-- CLAUDE_DOCS_REVIEW -->"))] | sort_by(.created_at) | .[].id')
# Count comments
COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
if [ "$COMMENT_COUNT" -gt 1 ]; then
# Delete all but the last (newest) comment
echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
if [ -n "$COMMENT_ID" ]; then
echo "Deleting old review comment: $COMMENT_ID"
gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
fi
done
else
echo "No old review comments to clean up"
fi

View File

@@ -28,7 +28,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 1 fetch-depth: 1
@@ -38,7 +38,7 @@ jobs:
python-version: "3.11" python-version: "3.11"
- name: Set up Python dependency cache - name: Set up Python dependency cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/.cache/pypoetry path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -25,7 +25,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
ref: ${{ github.event.inputs.git_ref || github.ref_name }} ref: ${{ github.event.inputs.git_ref || github.ref_name }}
@@ -52,7 +52,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Trigger deploy workflow - name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4 uses: peter-evans/repository-dispatch@v3
with: with:
token: ${{ secrets.DEPLOY_TOKEN }} token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -17,7 +17,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
ref: ${{ github.ref_name || 'master' }} ref: ${{ github.ref_name || 'master' }}
@@ -45,7 +45,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Trigger deploy workflow - name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4 uses: peter-evans/repository-dispatch@v3
with: with:
token: ${{ secrets.DEPLOY_TOKEN }} token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -41,18 +41,13 @@ jobs:
ports: ports:
- 6379:6379 - 6379:6379
rabbitmq: rabbitmq:
image: rabbitmq:4.1.4 image: rabbitmq:3.12-management
ports: ports:
- 5672:5672 - 5672:5672
- 15672:15672
env: env:
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }} RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }} RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
options: >-
--health-cmd "rabbitmq-diagnostics -q ping"
--health-interval 30s
--health-timeout 10s
--health-retries 5
--health-start-period 10s
clamav: clamav:
image: clamav/clamav-debian:latest image: clamav/clamav-debian:latest
ports: ports:
@@ -73,7 +68,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
submodules: true submodules: true
@@ -93,7 +88,7 @@ jobs:
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache - name: Set up Python dependency cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/.cache/pypoetry path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -17,7 +17,7 @@ jobs:
- name: Check comment permissions and deployment status - name: Check comment permissions and deployment status
id: check_status id: check_status
if: github.event_name == 'issue_comment' && github.event.issue.pull_request if: github.event_name == 'issue_comment' && github.event.issue.pull_request
uses: actions/github-script@v8 uses: actions/github-script@v7
with: with:
script: | script: |
const commentBody = context.payload.comment.body.trim(); const commentBody = context.payload.comment.body.trim();
@@ -55,7 +55,7 @@ jobs:
- name: Post permission denied comment - name: Post permission denied comment
if: steps.check_status.outputs.permission_denied == 'true' if: steps.check_status.outputs.permission_denied == 'true'
uses: actions/github-script@v8 uses: actions/github-script@v7
with: with:
script: | script: |
await github.rest.issues.createComment({ await github.rest.issues.createComment({
@@ -68,7 +68,7 @@ jobs:
- name: Get PR details for deployment - name: Get PR details for deployment
id: pr_details id: pr_details
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true' if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
uses: actions/github-script@v8 uses: actions/github-script@v7
with: with:
script: | script: |
const pr = await github.rest.pulls.get({ const pr = await github.rest.pulls.get({
@@ -82,7 +82,7 @@ jobs:
- name: Dispatch Deploy Event - name: Dispatch Deploy Event
if: steps.check_status.outputs.should_deploy == 'true' if: steps.check_status.outputs.should_deploy == 'true'
uses: peter-evans/repository-dispatch@v4 uses: peter-evans/repository-dispatch@v3
with: with:
token: ${{ secrets.DISPATCH_TOKEN }} token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -98,7 +98,7 @@ jobs:
- name: Post deploy success comment - name: Post deploy success comment
if: steps.check_status.outputs.should_deploy == 'true' if: steps.check_status.outputs.should_deploy == 'true'
uses: actions/github-script@v8 uses: actions/github-script@v7
with: with:
script: | script: |
await github.rest.issues.createComment({ await github.rest.issues.createComment({
@@ -110,7 +110,7 @@ jobs:
- name: Dispatch Undeploy Event (from comment) - name: Dispatch Undeploy Event (from comment)
if: steps.check_status.outputs.should_undeploy == 'true' if: steps.check_status.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4 uses: peter-evans/repository-dispatch@v3
with: with:
token: ${{ secrets.DISPATCH_TOKEN }} token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -126,7 +126,7 @@ jobs:
- name: Post undeploy success comment - name: Post undeploy success comment
if: steps.check_status.outputs.should_undeploy == 'true' if: steps.check_status.outputs.should_undeploy == 'true'
uses: actions/github-script@v8 uses: actions/github-script@v7
with: with:
script: | script: |
await github.rest.issues.createComment({ await github.rest.issues.createComment({
@@ -139,7 +139,7 @@ jobs:
- name: Check deployment status on PR close - name: Check deployment status on PR close
id: check_pr_close id: check_pr_close
if: github.event_name == 'pull_request' && github.event.action == 'closed' if: github.event_name == 'pull_request' && github.event.action == 'closed'
uses: actions/github-script@v8 uses: actions/github-script@v7
with: with:
script: | script: |
const comments = await github.rest.issues.listComments({ const comments = await github.rest.issues.listComments({
@@ -168,7 +168,7 @@ jobs:
github.event_name == 'pull_request' && github.event_name == 'pull_request' &&
github.event.action == 'closed' && github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true' steps.check_pr_close.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4 uses: peter-evans/repository-dispatch@v3
with: with:
token: ${{ secrets.DISPATCH_TOKEN }} token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -187,7 +187,7 @@ jobs:
github.event_name == 'pull_request' && github.event_name == 'pull_request' &&
github.event.action == 'closed' && github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true' steps.check_pr_close.outputs.should_undeploy == 'true'
uses: actions/github-script@v8 uses: actions/github-script@v7
with: with:
script: | script: |
await github.rest.issues.createComment({ await github.rest.issues.createComment({

View File

@@ -6,16 +6,10 @@ on:
paths: paths:
- ".github/workflows/platform-frontend-ci.yml" - ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**" - "autogpt_platform/frontend/**"
- "autogpt_platform/backend/Dockerfile"
- "autogpt_platform/docker-compose.yml"
- "autogpt_platform/docker-compose.platform.yml"
pull_request: pull_request:
paths: paths:
- ".github/workflows/platform-frontend-ci.yml" - ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**" - "autogpt_platform/frontend/**"
- "autogpt_platform/backend/Dockerfile"
- "autogpt_platform/docker-compose.yml"
- "autogpt_platform/docker-compose.platform.yml"
merge_group: merge_group:
workflow_dispatch: workflow_dispatch:
@@ -32,11 +26,12 @@ jobs:
setup: setup:
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs: outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
components-changed: ${{ steps.filter.outputs.components }} components-changed: ${{ steps.filter.outputs.components }}
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
- name: Check for component changes - name: Check for component changes
uses: dorny/paths-filter@v3 uses: dorny/paths-filter@v3
@@ -46,17 +41,28 @@ jobs:
components: components:
- 'autogpt_platform/frontend/src/components/**' - 'autogpt_platform/frontend/src/components/**'
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack - name: Enable corepack
run: corepack enable run: corepack enable
- name: Set up Node - name: Generate cache key
uses: actions/setup-node@v6 id: cache-key
with: run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Install dependencies to populate cache - name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile run: pnpm install --frozen-lockfile
lint: lint:
@@ -65,17 +71,24 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack - name: Enable corepack
run: corepack enable run: corepack enable
- name: Set up Node - name: Restore dependencies cache
uses: actions/setup-node@v6 uses: actions/cache@v4
with: with:
node-version: "22.18.0" path: ~/.pnpm-store
cache: "pnpm" key: ${{ needs.setup.outputs.cache-key }}
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies - name: Install dependencies
run: pnpm install --frozen-lockfile run: pnpm install --frozen-lockfile
@@ -94,19 +107,26 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack - name: Enable corepack
run: corepack enable run: corepack enable
- name: Set up Node - name: Restore dependencies cache
uses: actions/setup-node@v6 uses: actions/cache@v4
with: with:
node-version: "22.18.0" path: ~/.pnpm-store
cache: "pnpm" key: ${{ needs.setup.outputs.cache-key }}
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies - name: Install dependencies
run: pnpm install --frozen-lockfile run: pnpm install --frozen-lockfile
@@ -121,20 +141,30 @@ jobs:
exitOnceUploaded: true exitOnceUploaded: true
e2e_test: e2e_test:
name: end-to-end tests
runs-on: big-boi runs-on: big-boi
needs: setup
strategy:
fail-fast: false
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
submodules: recursive submodules: recursive
- name: Set up Platform - Copy default supabase .env - name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: | run: |
cp ../.env.default ../.env cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key - name: Copy backend .env and set OpenAI API key
run: | run: |
cp ../backend/.env.default ../backend/.env cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
@@ -142,125 +172,77 @@ jobs:
# Used by E2E test data script to generate embeddings for approved store agents # Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: Cache Docker layers
uses: actions/cache@v4
with: with:
driver: docker-container path: /tmp/.buildx-cache
driver-opts: network=host key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
restore-keys: |
${{ runner.os }}-buildx-frontend-test-
- name: Set up Platform - Expose GHA cache to docker buildx CLI - name: Run docker compose
uses: crazy-max/ghaction-github-runtime@v3
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
run: | run: |
pip install pyyaml NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env: env:
NEXT_PUBLIC_PW_TEST: true DOCKER_BUILDKIT: 1
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Set up tests - Cache E2E test data - name: Move cache
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
run: | run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build rm -rf /tmp/.buildx-cache
echo "Waiting for database to be ready..." if [ -d "/tmp/.buildx-cache-new" ]; then
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' mv /tmp/.buildx-cache-new /tmp/.buildx-cache
echo "Waiting for auth service to be ready..." fi
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
- name: Set up Platform - Run migrations - name: Wait for services to be ready
run: | run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..." echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..." timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env: echo "Waiting for database to be ready..."
NEXT_PUBLIC_PW_TEST: true timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Set up tests - Create E2E test data - name: Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
run: | run: |
echo "Creating E2E test data..." echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py # First try to run the script from inside the container
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || { if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
echo "❌ E2E test data creation failed!" echo "✅ Found e2e_test_data.py in container, running it..."
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
exit 1 echo "❌ E2E test data creation failed!"
} docker compose -f ../docker-compose.yml logs --tail=50 rest_server
exit 1
}
else
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
# Copy the script into the container and run it
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
echo "❌ Failed to copy script to container"
exit 1
}
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
exit 1
}
fi
# Dump auth.users + platform schema for cache (two separate dumps) - name: Restore dependencies cache
echo "Dumping database for cache..." uses: actions/cache@v4
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
with: with:
node-version: "22.18.0" path: ~/.pnpm-store
cache: "pnpm" key: ${{ needs.setup.outputs.cache-key }}
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Set up tests - Install dependencies - name: Install dependencies
run: pnpm install --frozen-lockfile run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium' - name: Install Browser 'chromium'
run: pnpm playwright install --with-deps chromium run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests - name: Run Playwright tests
@@ -287,7 +269,7 @@ jobs:
- name: Print Final Docker Compose logs - name: Print Final Docker Compose logs
if: always() if: always()
run: docker compose -f ../docker-compose.resolved.yml logs run: docker compose -f ../docker-compose.yml logs
integration_test: integration_test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -295,19 +277,26 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
submodules: recursive submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack - name: Enable corepack
run: corepack enable run: corepack enable
- name: Set up Node - name: Restore dependencies cache
uses: actions/setup-node@v6 uses: actions/cache@v4
with: with:
node-version: "22.18.0" path: ~/.pnpm-store
cache: "pnpm" key: ${{ needs.setup.outputs.cache-key }}
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies - name: Install dependencies
run: pnpm install --frozen-lockfile run: pnpm install --frozen-lockfile

View File

@@ -29,10 +29,10 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v6 uses: actions/setup-node@v4
with: with:
node-version: "22.18.0" node-version: "22.18.0"
@@ -44,7 +44,7 @@ jobs:
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies - name: Cache dependencies
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/.pnpm-store path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }} key: ${{ steps.cache-key.outputs.key }}
@@ -56,19 +56,19 @@ jobs:
run: pnpm install --frozen-lockfile run: pnpm install --frozen-lockfile
types: types:
runs-on: big-boi runs-on: ubuntu-latest
needs: setup needs: setup
strategy: strategy:
fail-fast: false fail-fast: false
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v4
with: with:
submodules: recursive submodules: recursive
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v6 uses: actions/setup-node@v4
with: with:
node-version: "22.18.0" node-version: "22.18.0"
@@ -85,10 +85,10 @@ jobs:
- name: Run docker compose - name: Run docker compose
run: | run: |
docker compose -f ../docker-compose.yml --profile local up -d deps_backend docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache - name: Restore dependencies cache
uses: actions/cache@v5 uses: actions/cache@v4
with: with:
path: ~/.pnpm-store path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }} key: ${{ needs.setup.outputs.cache-key }}

View File

@@ -1,39 +0,0 @@
name: PR Overlap Detection
on:
pull_request:
types: [opened, synchronize, reopened]
branches:
- dev
- master
permissions:
contents: read
pull-requests: write
jobs:
check-overlaps:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # Need full history for merge testing
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Configure git
run: |
git config user.email "github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"
- name: Run overlap detection
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Always succeed - this check informs contributors, it shouldn't block merging
continue-on-error: true
run: |
python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}

View File

@@ -11,7 +11,7 @@ jobs:
steps: steps:
# - name: Wait some time for all actions to start # - name: Wait some time for all actions to start
# run: sleep 30 # run: sleep 30
- uses: actions/checkout@v6 - uses: actions/checkout@v4
# with: # with:
# fetch-depth: 0 # fetch-depth: 0
- name: Set up Python - name: Set up Python

View File

@@ -1,195 +0,0 @@
#!/usr/bin/env python3
"""
Add cache configuration to a resolved docker-compose file for all services
that have a build key, and ensure image names match what docker compose expects.
"""
import argparse
import yaml
DEFAULT_BRANCH = "dev"
CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]
def main():
parser = argparse.ArgumentParser(
description="Add cache config to a resolved compose file"
)
parser.add_argument(
"--source",
required=True,
help="Source compose file to read (should be output of `docker compose config`)",
)
parser.add_argument(
"--cache-from",
default="type=gha",
help="Cache source configuration",
)
parser.add_argument(
"--cache-to",
default="type=gha,mode=max",
help="Cache destination configuration",
)
for component in CACHE_BUILDS_FOR_COMPONENTS:
parser.add_argument(
f"--{component}-hash",
default="",
help=f"Hash for {component} cache scope (e.g., from hashFiles())",
)
parser.add_argument(
"--git-ref",
default="",
help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
)
args = parser.parse_args()
# Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
git_ref_scope = ""
if args.git_ref:
git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")
with open(args.source, "r") as f:
compose = yaml.safe_load(f)
# Get project name from compose file or default
project_name = compose.get("name", "autogpt_platform")
def get_image_name(dockerfile: str, target: str) -> str:
"""Generate image name based on Dockerfile folder and build target."""
dockerfile_parts = dockerfile.replace("\\", "/").split("/")
if len(dockerfile_parts) >= 2:
folder_name = dockerfile_parts[-2] # e.g., "backend" or "frontend"
else:
folder_name = "app"
return f"{project_name}-{folder_name}:{target}"
def get_build_key(dockerfile: str, target: str) -> str:
"""Generate a unique key for a Dockerfile+target combination."""
return f"{dockerfile}:{target}"
def get_component(dockerfile: str) -> str | None:
"""Get component name (frontend/backend) from dockerfile path."""
for component in CACHE_BUILDS_FOR_COMPONENTS:
if component in dockerfile:
return component
return None
# First pass: collect all services with build configs and identify duplicates
# Track which (dockerfile, target) combinations we've seen
build_key_to_first_service: dict[str, str] = {}
services_to_build: list[str] = []
services_to_dedupe: list[str] = []
for service_name, service_config in compose.get("services", {}).items():
if "build" not in service_config:
continue
build_config = service_config["build"]
dockerfile = build_config.get("dockerfile", "Dockerfile")
target = build_config.get("target", "default")
build_key = get_build_key(dockerfile, target)
if build_key not in build_key_to_first_service:
# First service with this build config - it will do the actual build
build_key_to_first_service[build_key] = service_name
services_to_build.append(service_name)
else:
# Duplicate - will just use the image from the first service
services_to_dedupe.append(service_name)
# Second pass: configure builds and deduplicate
modified_services = []
for service_name, service_config in compose.get("services", {}).items():
if "build" not in service_config:
continue
build_config = service_config["build"]
dockerfile = build_config.get("dockerfile", "Dockerfile")
target = build_config.get("target", "latest")
image_name = get_image_name(dockerfile, target)
# Set image name for all services (needed for both builders and deduped)
service_config["image"] = image_name
if service_name in services_to_dedupe:
# Remove build config - this service will use the pre-built image
del service_config["build"]
continue
# This service will do the actual build - add cache config
cache_from_list = []
cache_to_list = []
component = get_component(dockerfile)
if not component:
# Skip services that don't clearly match frontend/backend
continue
# Get the hash for this component
component_hash = getattr(args, f"{component}_hash")
# Scope format: platform-{component}-{target}-{hash|ref}
# Example: platform-backend-server-abc123
if "type=gha" in args.cache_from:
# 1. Primary: exact hash match (most specific)
if component_hash:
hash_scope = f"platform-{component}-{target}-{component_hash}"
cache_from_list.append(f"{args.cache_from},scope={hash_scope}")
# 2. Fallback: branch-based cache
if git_ref_scope:
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
cache_from_list.append(f"{args.cache_from},scope={ref_scope}")
# 3. Fallback: dev branch cache (for PRs/feature branches)
if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
cache_from_list.append(f"{args.cache_from},scope={master_scope}")
if "type=gha" in args.cache_to:
# Write to both hash-based and branch-based scopes
if component_hash:
hash_scope = f"platform-{component}-{target}-{component_hash}"
cache_to_list.append(f"{args.cache_to},scope={hash_scope}")
if git_ref_scope:
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
cache_to_list.append(f"{args.cache_to},scope={ref_scope}")
# Ensure we have at least one cache source/target
if not cache_from_list:
cache_from_list.append(args.cache_from)
if not cache_to_list:
cache_to_list.append(args.cache_to)
build_config["cache_from"] = cache_from_list
build_config["cache_to"] = cache_to_list
modified_services.append(service_name)
# Write back to the same file
with open(args.source, "w") as f:
yaml.dump(compose, f, default_flow_style=False, sort_keys=False)
print(f"Added cache config to {len(modified_services)} services in {args.source}:")
for svc in modified_services:
svc_config = compose["services"][svc]
build_cfg = svc_config.get("build", {})
cache_from_list = build_cfg.get("cache_from", ["none"])
cache_to_list = build_cfg.get("cache_to", ["none"])
print(f" - {svc}")
print(f" image: {svc_config.get('image', 'N/A')}")
print(f" cache_from: {cache_from_list}")
print(f" cache_to: {cache_to_list}")
if services_to_dedupe:
print(
f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
)
for svc in services_to_dedupe:
print(f" - {svc} -> {compose['services'][svc].get('image', 'N/A')}")
if __name__ == "__main__":
main()

View File

@@ -45,11 +45,6 @@ AutoGPT Platform is a monorepo containing:
- Backend/Frontend services use YAML anchors for consistent configuration - Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern - Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests ### Creating Pull Requests
- Create the PR against the `dev` branch of the repository. - Create the PR against the `dev` branch of the repository.

File diff suppressed because it is too large Load Diff

View File

@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.10,<4.0" python = ">=3.10,<4.0"
colorama = "^0.4.6" colorama = "^0.4.6"
cryptography = "^46.0" cryptography = "^45.0"
expiringdict = "^1.2.2" expiringdict = "^1.2.2"
fastapi = "^0.128.7" fastapi = "^0.116.1"
google-cloud-logging = "^3.13.0" google-cloud-logging = "^3.12.1"
launchdarkly-server-sdk = "^9.15.0" launchdarkly-server-sdk = "^9.12.0"
pydantic = "^2.12.5" pydantic = "^2.11.7"
pydantic-settings = "^2.12.0" pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.11.0", extras = ["crypto"] } pyjwt = { version = "^2.10.1", extras = ["crypto"] }
redis = "^6.2.0" redis = "^6.2.0"
supabase = "^2.28.0" supabase = "^2.16.0"
uvicorn = "^0.40.0" uvicorn = "^0.35.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pyright = "^1.1.408" pyright = "^1.1.404"
pytest = "^8.4.1" pytest = "^8.4.1"
pytest-asyncio = "^1.3.0" pytest-asyncio = "^1.1.0"
pytest-mock = "^3.15.1" pytest-mock = "^3.14.1"
pytest-cov = "^7.0.0" pytest-cov = "^6.2.1"
ruff = "^0.15.0" ruff = "^0.12.11"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]

View File

@@ -104,12 +104,6 @@ TWITTER_CLIENT_SECRET=
# Make a new workspace for your OAuth APP -- trust me # Make a new workspace for your OAuth APP -- trust me
# https://linear.app/settings/api/applications/new # https://linear.app/settings/api/applications/new
# Callback URL: http://localhost:3000/auth/integrations/oauth_callback # Callback URL: http://localhost:3000/auth/integrations/oauth_callback
LINEAR_API_KEY=
# Linear project and team IDs for the feature request tracker.
# Find these in your Linear workspace URL: linear.app/<workspace>/project/<project-id>
# and in team settings. Used by the chat copilot to file and search feature requests.
LINEAR_FEATURE_REQUEST_PROJECT_ID=
LINEAR_FEATURE_REQUEST_TEAM_ID=
LINEAR_CLIENT_ID= LINEAR_CLIENT_ID=
LINEAR_CLIENT_SECRET= LINEAR_CLIENT_SECRET=

View File

@@ -1,5 +1,3 @@
# ============================ DEPENDENCY BUILDER ============================ #
FROM debian:13-slim AS builder FROM debian:13-slim AS builder
# Set environment variables # Set environment variables
@@ -53,62 +51,27 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./ COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub RUN poetry run prisma generate && poetry run gen-prisma-stub
# =============================== DB MIGRATOR =============================== # FROM debian:13-slim AS server_dependencies
# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
FROM debian:13-slim AS migrate
WORKDIR /app/autogpt_platform/backend
ENV DEBIAN_FRONTEND=noninteractive
# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
python3-pip \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Copy Node.js from builder (needed for Prisma CLI)
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
# Copy Prisma binaries
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
# Install prisma-client-py directly (much smaller than copying full venv)
RUN pip3 install prisma>=0.15.0 --break-system-packages
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
COPY autogpt_platform/backend/migrations ./migrations
# ============================== BACKEND SERVER ============================== #
FROM debian:13-slim AS server
WORKDIR /app WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive ENV POETRY_HOME=/opt/poetry \
POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=true \
POETRY_VIRTUALENVS_IN_PROJECT=true \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use. # Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network) RUN apt-get update && apt-get install -y \
# for the bash_exec MCP tool.
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \ python3.13 \
python3-pip \ python3-pip \
ffmpeg \ ffmpeg \
imagemagick \ imagemagick \
jq \
ripgrep \
tree \
bubblewrap \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points) # Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3* COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma # Copy Node.js installation for Prisma
@@ -118,25 +81,30 @@ COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)
# The .venv includes the generated Prisma client
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH" ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
# Copy dependency files + autogpt_libs (path dependency) RUN mkdir -p /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs RUN mkdir -p /app/autogpt_platform/backend
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
# Copy backend code + docs (for Copilot docs search) COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend ./
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
WORKDIR /app/autogpt_platform/backend
FROM server_dependencies AS migrate
# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
COPY docs /app/docs COPY docs /app/docs
# Install the project package to create entry point scripts in .venv/bin/ RUN poetry install --no-ansi --only-root
# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
poetry install --no-ansi --only-root
ENV PORT=8000 ENV PORT=8000
CMD ["rest"] CMD ["poetry", "run", "rest"]

View File

@@ -1,9 +1,4 @@
"""Common test fixtures for server tests. """Common test fixtures for server tests."""
Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
setup_test_user, and setup_admin_user are defined in the parent conftest.py
(backend/conftest.py) and are available here automatically.
"""
import pytest import pytest
from pytest_snapshot.plugin import Snapshot from pytest_snapshot.plugin import Snapshot
@@ -16,6 +11,54 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
return snapshot return snapshot
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
@pytest.fixture
async def setup_test_user(test_user_id):
"""Create test user in database before tests."""
from backend.data.user import get_or_create_user
# Create the test user in the database using JWT token format
user_data = {
"sub": test_user_id,
"email": "test@example.com",
"user_metadata": {"name": "Test User"},
}
await get_or_create_user(user_data)
return test_user_id
@pytest.fixture
async def setup_admin_user(admin_user_id):
"""Create admin user in database before tests."""
from backend.data.user import get_or_create_user
# Create the admin user in the database using JWT token format
user_data = {
"sub": admin_user_id,
"email": "test-admin@example.com",
"user_metadata": {"name": "Test Admin"},
}
await get_or_create_user(user_data)
return admin_user_id
@pytest.fixture @pytest.fixture
def mock_jwt_user(test_user_id): def mock_jwt_user(test_user_id):
"""Provide mock JWT payload for regular user testing.""" """Provide mock JWT payload for regular user testing."""

View File

@@ -10,7 +10,7 @@ from typing_extensions import TypedDict
import backend.api.features.store.cache as store_cache import backend.api.features.store.cache as store_cache
import backend.api.features.store.model as store_model import backend.api.features.store.model as store_model
import backend.blocks import backend.data.block
from backend.api.external.middleware import require_permission from backend.api.external.middleware import require_permission
from backend.data import execution as execution_db from backend.data import execution as execution_db
from backend.data import graph as graph_db from backend.data import graph as graph_db
@@ -67,7 +67,7 @@ async def get_user_info(
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))], dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
) )
async def get_graph_blocks() -> Sequence[dict[Any, Any]]: async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.blocks.get_blocks().values()] blocks = [block() for block in backend.data.block.get_blocks().values()]
return [b.to_dict() for b in blocks if not b.disabled] return [b.to_dict() for b in blocks if not b.disabled]
@@ -83,7 +83,7 @@ async def execute_graph_block(
require_permission(APIKeyPermission.EXECUTE_BLOCK) require_permission(APIKeyPermission.EXECUTE_BLOCK)
), ),
) -> CompletedBlockOutput: ) -> CompletedBlockOutput:
obj = backend.blocks.get_block(block_id) obj = backend.data.block.get_block(block_id)
if not obj: if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.") raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
if obj.disabled: if obj.disabled:

View File

@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission from backend.api.external.middleware import require_permission
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.copilot.tools import find_agent_tool, run_agent_tool from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
from backend.copilot.tools.models import ToolResponseBase from backend.api.features.chat.tools.models import ToolResponseBase
from backend.data.auth.base import APIAuthorizationInfo from backend.data.auth.base import APIAuthorizationInfo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -10,15 +10,10 @@ import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model import backend.api.features.store.model as store_model
import backend.data.block
from backend.blocks import load_all_blocks from backend.blocks import load_all_blocks
from backend.blocks._base import (
AnyBlockSchema,
BlockCategory,
BlockInfo,
BlockSchema,
BlockType,
)
from backend.blocks.llm import LlmModel from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
from backend.data.db import query_raw_with_schema from backend.data.db import query_raw_with_schema
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.cache import cached from backend.util.cache import cached
@@ -27,7 +22,7 @@ from backend.util.models import Pagination
from .model import ( from .model import (
BlockCategoryResponse, BlockCategoryResponse,
BlockResponse, BlockResponse,
BlockTypeFilter, BlockType,
CountResponse, CountResponse,
FilterType, FilterType,
Provider, Provider,
@@ -93,7 +88,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
def get_blocks( def get_blocks(
*, *,
category: str | None = None, category: str | None = None,
type: BlockTypeFilter | None = None, type: BlockType | None = None,
provider: ProviderName | None = None, provider: ProviderName | None = None,
page: int = 1, page: int = 1,
page_size: int = 50, page_size: int = 50,
@@ -674,9 +669,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
for block_type in load_all_blocks().values(): for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type() block: AnyBlockSchema = block_type()
if block.disabled or block.block_type in ( if block.disabled or block.block_type in (
BlockType.INPUT, backend.data.block.BlockType.INPUT,
BlockType.OUTPUT, backend.data.block.BlockType.OUTPUT,
BlockType.AGENT, backend.data.block.BlockType.AGENT,
): ):
continue continue
# Find the execution count for this block # Find the execution count for this block

View File

@@ -4,7 +4,7 @@ from pydantic import BaseModel
import backend.api.features.library.model as library_model import backend.api.features.library.model as library_model
import backend.api.features.store.model as store_model import backend.api.features.store.model as store_model
from backend.blocks._base import BlockInfo from backend.data.block import BlockInfo
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.models import Pagination from backend.util.models import Pagination
@@ -15,7 +15,7 @@ FilterType = Literal[
"my_agents", "my_agents",
] ]
BlockTypeFilter = Literal["all", "input", "action", "output"] BlockType = Literal["all", "input", "action", "output"]
class SearchEntry(BaseModel): class SearchEntry(BaseModel):

View File

@@ -88,7 +88,7 @@ async def get_block_categories(
) )
async def get_blocks( async def get_blocks(
category: Annotated[str | None, fastapi.Query()] = None, category: Annotated[str | None, fastapi.Query()] = None,
type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None, type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
provider: Annotated[ProviderName | None, fastapi.Query()] = None, provider: Annotated[ProviderName | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1, page: Annotated[int, fastapi.Query()] = 1,
page_size: Annotated[int, fastapi.Query()] = 50, page_size: Annotated[int, fastapi.Query()] = 50,

View File

@@ -37,10 +37,12 @@ stale pending messages from dead consumers.
import asyncio import asyncio
import logging import logging
import os
import uuid import uuid
from typing import Any from typing import Any
import orjson import orjson
from prisma import Prisma
from pydantic import BaseModel from pydantic import BaseModel
from redis.exceptions import ResponseError from redis.exceptions import ResponseError
@@ -67,8 +69,8 @@ class OperationCompleteMessage(BaseModel):
class ChatCompletionConsumer: class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams. """Consumer for chat operation completion messages from Redis Streams.
Database operations are handled through the chat_db() accessor, which This consumer initializes its own Prisma client in start() to ensure
routes through DatabaseManager RPC when Prisma is not directly connected. database operations work correctly within this async context.
Uses Redis consumer groups to allow multiple platform pods to consume Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure. messages reliably with automatic redelivery on failure.
@@ -77,6 +79,7 @@ class ChatCompletionConsumer:
def __init__(self): def __init__(self):
self._consumer_task: asyncio.Task | None = None self._consumer_task: asyncio.Task | None = None
self._running = False self._running = False
self._prisma: Prisma | None = None
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}" self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None: async def start(self) -> None:
@@ -112,6 +115,15 @@ class ChatCompletionConsumer:
f"Chat completion consumer started (consumer: {self._consumer_name})" f"Chat completion consumer started (consumer: {self._consumer_name})"
) )
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
if self._prisma is None:
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
self._prisma = Prisma(datasource={"url": database_url})
await self._prisma.connect()
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
return self._prisma
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the completion consumer.""" """Stop the completion consumer."""
self._running = False self._running = False
@@ -124,6 +136,11 @@ class ChatCompletionConsumer:
pass pass
self._consumer_task = None self._consumer_task = None
if self._prisma:
await self._prisma.disconnect()
self._prisma = None
logger.info("[COMPLETION] Consumer Prisma client disconnected")
logger.info("Chat completion consumer stopped") logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None: async def _consume_messages(self) -> None:
@@ -235,7 +252,7 @@ class ChatCompletionConsumer:
# XAUTOCLAIM after min_idle_time expires # XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None: async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message.""" """Handle a completion message using our own Prisma client."""
try: try:
data = orjson.loads(body) data = orjson.loads(body)
message = OperationCompleteMessage(**data) message = OperationCompleteMessage(**data)
@@ -285,7 +302,8 @@ class ChatCompletionConsumer:
message: OperationCompleteMessage, message: OperationCompleteMessage,
) -> None: ) -> None:
"""Handle successful operation completion.""" """Handle successful operation completion."""
await process_operation_success(task, message.result) prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)
async def _handle_failure( async def _handle_failure(
self, self,
@@ -293,7 +311,8 @@ class ChatCompletionConsumer:
message: OperationCompleteMessage, message: OperationCompleteMessage,
) -> None: ) -> None:
"""Handle failed operation completion.""" """Handle failed operation completion."""
await process_operation_failure(task, message.error) prisma = await self._ensure_prisma()
await process_operation_failure(task, message.error, prisma)
# Module-level consumer instance # Module-level consumer instance

View File

@@ -9,8 +9,7 @@ import logging
from typing import Any from typing import Any
import orjson import orjson
from prisma import Prisma
from backend.data.db_accessors import chat_db
from . import service as chat_service from . import service as chat_service
from . import stream_registry from . import stream_registry
@@ -73,40 +72,48 @@ async def _update_tool_message(
session_id: str, session_id: str,
tool_call_id: str, tool_call_id: str,
content: str, content: str,
prisma_client: Prisma | None,
) -> None: ) -> None:
"""Update tool message in database using the chat_db accessor. """Update tool message in database.
Routes through DatabaseManager RPC when Prisma is not directly
connected (e.g. in the CoPilot Executor microservice).
Args: Args:
session_id: The session ID session_id: The session ID
tool_call_id: The tool call ID to update tool_call_id: The tool call ID to update
content: The new content for the message content: The new content for the message
prisma_client: Optional Prisma client. If None, uses chat_service.
Raises: Raises:
ToolMessageUpdateError: If the database update fails. ToolMessageUpdateError: If the database update fails. The caller should
handle this to avoid marking the task as completed with inconsistent state.
""" """
try: try:
updated = await chat_db().update_tool_message_content( if prisma_client:
session_id=session_id, # Use provided Prisma client (for consumer with its own connection)
tool_call_id=tool_call_id, updated_count = await prisma_client.chatmessage.update_many(
new_content=content, where={
) "sessionId": session_id,
if not updated: "toolCallId": tool_call_id,
raise ToolMessageUpdateError( },
f"No message found with tool_call_id=" data={"content": content},
f"{tool_call_id} in session {session_id}" )
# Check if any rows were updated - 0 means message not found
if updated_count == 0:
raise ToolMessageUpdateError(
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=content,
) )
except ToolMessageUpdateError: except ToolMessageUpdateError:
raise raise
except Exception as e: except Exception as e:
logger.error( logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
f"[COMPLETION] Failed to update tool message: {e}",
exc_info=True,
)
raise ToolMessageUpdateError( raise ToolMessageUpdateError(
f"Failed to update tool message for tool call #{tool_call_id}: {e}" f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
) from e ) from e
@@ -195,6 +202,7 @@ async def _save_agent_from_result(
async def process_operation_success( async def process_operation_success(
task: stream_registry.ActiveTask, task: stream_registry.ActiveTask,
result: dict | str | None, result: dict | str | None,
prisma_client: Prisma | None = None,
) -> None: ) -> None:
"""Handle successful operation completion. """Handle successful operation completion.
@@ -204,10 +212,12 @@ async def process_operation_success(
Args: Args:
task: The active task that completed task: The active task that completed
result: The result data from the operation result: The result data from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
Raises: Raises:
ToolMessageUpdateError: If the database update fails. The task ToolMessageUpdateError: If the database update fails. The task will be
will be marked as failed instead of completed. marked as failed instead of completed to avoid inconsistent state.
""" """
# For agent generation tools, save the agent to library # For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict): if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
@@ -240,6 +250,7 @@ async def process_operation_success(
session_id=task.session_id, session_id=task.session_id,
tool_call_id=task.tool_call_id, tool_call_id=task.tool_call_id,
content=result_str, content=result_str,
prisma_client=prisma_client,
) )
except ToolMessageUpdateError: except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state # DB update failed - mark task as failed to avoid inconsistent state
@@ -282,15 +293,18 @@ async def process_operation_success(
async def process_operation_failure( async def process_operation_failure(
task: stream_registry.ActiveTask, task: stream_registry.ActiveTask,
error: str | None, error: str | None,
prisma_client: Prisma | None = None,
) -> None: ) -> None:
"""Handle failed operation completion. """Handle failed operation completion.
Publishes the error to the stream registry, updates the database Publishes the error to the stream registry, updates the database with
with the error response, and marks the task as failed. the error response, and marks the task as failed.
Args: Args:
task: The active task that failed task: The active task that failed
error: The error message from the operation error: The error message from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
""" """
error_msg = error or "Operation failed" error_msg = error or "Operation failed"
@@ -311,6 +325,7 @@ async def process_operation_failure(
session_id=task.session_id, session_id=task.session_id,
tool_call_id=task.tool_call_id, tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(), content=error_response.model_dump_json(),
prisma_client=prisma_client,
) )
except ToolMessageUpdateError: except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup # DB update failed - log but continue with cleanup

View File

@@ -27,11 +27,12 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds") session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration # Streaming Configuration
stream_timeout: int = Field(default=300, description="Stream timeout in seconds") max_context_messages: int = Field(
max_retries: int = Field( default=50, ge=1, le=200, description="Maximum context messages"
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
) )
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs") max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field( max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules" default=30, description="Maximum number of agent schedules"
@@ -92,37 +93,6 @@ class ChatConfig(BaseSettings):
description="Name of the prompt in Langfuse to fetch", description="Name of the prompt in Langfuse to fetch",
) )
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK for chat completions",
)
claude_agent_model: str | None = Field(
default=None,
description="Model for the Claude Agent SDK path. If None, derives from "
"the `model` field by stripping the OpenRouter provider prefix.",
)
claude_agent_max_buffer_size: int = Field(
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
"Increase if tool outputs exceed the limit.",
)
claude_agent_max_subtasks: int = Field(
default=10,
description="Max number of sub-agent Tasks the SDK can spawn per session.",
)
claude_agent_use_resume: bool = Field(
default=True,
description="Use --resume for multi-turn conversations instead of "
"history compression. Falls back to compression when unavailable.",
)
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
description="Enable adaptive thinking for Claude models via OpenRouter",
)
@field_validator("api_key", mode="before") @field_validator("api_key", mode="before")
@classmethod @classmethod
def get_api_key(cls, v): def get_api_key(cls, v):
@@ -162,17 +132,6 @@ class ChatConfig(BaseSettings):
v = os.getenv("CHAT_INTERNAL_API_KEY") v = os.getenv("CHAT_INTERNAL_API_KEY")
return v return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):
"""Get use_claude_agent_sdk from environment if not provided."""
# Check environment variable - default to True if not set
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
# Default to True (SDK enabled by default)
return True if v is None else v
# Prompt paths for different contexts # Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = { PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md", "default": "prompts/chat_system.md",

View File

@@ -14,27 +14,29 @@ from prisma.types import (
ChatSessionWhereInput, ChatSessionWhereInput,
) )
from backend.data import db from backend.data.db import transaction
from backend.util.json import SafeJson from backend.util.json import SafeJson
from .model import ChatMessage, ChatSession, ChatSessionInfo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def get_chat_session(session_id: str) -> ChatSession | None: async def get_chat_session(session_id: str) -> PrismaChatSession | None:
"""Get a chat session by ID from the database.""" """Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique( session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id}, where={"id": session_id},
include={"Messages": {"order_by": {"sequence": "asc"}}}, include={"Messages": True},
) )
return ChatSession.from_db(session) if session else None if session and session.Messages:
# Sort messages by sequence in Python - Prisma Python client doesn't support
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
session.Messages.sort(key=lambda m: m.sequence)
return session
async def create_chat_session( async def create_chat_session(
session_id: str, session_id: str,
user_id: str, user_id: str,
) -> ChatSessionInfo: ) -> PrismaChatSession:
"""Create a new chat session in the database.""" """Create a new chat session in the database."""
data = ChatSessionCreateInput( data = ChatSessionCreateInput(
id=session_id, id=session_id,
@@ -43,8 +45,10 @@ async def create_chat_session(
successfulAgentRuns=SafeJson({}), successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}), successfulAgentSchedules=SafeJson({}),
) )
prisma_session = await PrismaChatSession.prisma().create(data=data) return await PrismaChatSession.prisma().create(
return ChatSessionInfo.from_db(prisma_session) data=data,
include={"Messages": True},
)
async def update_chat_session( async def update_chat_session(
@@ -55,7 +59,7 @@ async def update_chat_session(
total_prompt_tokens: int | None = None, total_prompt_tokens: int | None = None,
total_completion_tokens: int | None = None, total_completion_tokens: int | None = None,
title: str | None = None, title: str | None = None,
) -> ChatSession | None: ) -> PrismaChatSession | None:
"""Update a chat session's metadata.""" """Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)} data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
@@ -75,9 +79,12 @@ async def update_chat_session(
session = await PrismaChatSession.prisma().update( session = await PrismaChatSession.prisma().update(
where={"id": session_id}, where={"id": session_id},
data=data, data=data,
include={"Messages": {"order_by": {"sequence": "asc"}}}, include={"Messages": True},
) )
return ChatSession.from_db(session) if session else None if session and session.Messages:
# Sort in Python - Prisma Python doesn't support order_by in include clauses
session.Messages.sort(key=lambda m: m.sequence)
return session
async def add_chat_message( async def add_chat_message(
@@ -90,7 +97,7 @@ async def add_chat_message(
refusal: str | None = None, refusal: str | None = None,
tool_calls: list[dict[str, Any]] | None = None, tool_calls: list[dict[str, Any]] | None = None,
function_call: dict[str, Any] | None = None, function_call: dict[str, Any] | None = None,
) -> ChatMessage: ) -> PrismaChatMessage:
"""Add a message to a chat session.""" """Add a message to a chat session."""
# Build input dict dynamically rather than using ChatMessageCreateInput directly # Build input dict dynamically rather than using ChatMessageCreateInput directly
# because Prisma's TypedDict validation rejects optional fields set to None. # because Prisma's TypedDict validation rejects optional fields set to None.
@@ -125,14 +132,14 @@ async def add_chat_message(
), ),
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)), PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
) )
return ChatMessage.from_db(message) return message
async def add_chat_messages_batch( async def add_chat_messages_batch(
session_id: str, session_id: str,
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
start_sequence: int, start_sequence: int,
) -> list[ChatMessage]: ) -> list[PrismaChatMessage]:
"""Add multiple messages to a chat session in a batch. """Add multiple messages to a chat session in a batch.
Uses a transaction for atomicity - if any message creation fails, Uses a transaction for atomicity - if any message creation fails,
@@ -143,7 +150,7 @@ async def add_chat_messages_batch(
created_messages = [] created_messages = []
async with db.transaction() as tx: async with transaction() as tx:
for i, msg in enumerate(messages): for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput # Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields # directly because Prisma's TypedDict validation rejects optional fields
@@ -183,22 +190,21 @@ async def add_chat_messages_batch(
data={"updatedAt": datetime.now(UTC)}, data={"updatedAt": datetime.now(UTC)},
) )
return [ChatMessage.from_db(m) for m in created_messages] return created_messages
async def get_user_chat_sessions( async def get_user_chat_sessions(
user_id: str, user_id: str,
limit: int = 50, limit: int = 50,
offset: int = 0, offset: int = 0,
) -> list[ChatSessionInfo]: ) -> list[PrismaChatSession]:
"""Get chat sessions for a user, ordered by most recent.""" """Get chat sessions for a user, ordered by most recent."""
prisma_sessions = await PrismaChatSession.prisma().find_many( return await PrismaChatSession.prisma().find_many(
where={"userId": user_id}, where={"userId": user_id},
order={"updatedAt": "desc"}, order={"updatedAt": "desc"},
take=limit, take=limit,
skip=offset, skip=offset,
) )
return [ChatSessionInfo.from_db(s) for s in prisma_sessions]
async def get_user_session_count(user_id: str) -> int: async def get_user_session_count(user_id: str) -> int:

View File

@@ -2,7 +2,7 @@ import asyncio
import logging import logging
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any, Self, cast from typing import Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from openai.types.chat import ( from openai.types.chat import (
@@ -23,17 +23,26 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel from pydantic import BaseModel
from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async from backend.data.redis_client import get_redis_async
from backend.util import json from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError from backend.util.exceptions import DatabaseError, RedisError
from . import db as chat_db
from .config import ChatConfig from .config import ChatConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
config = ChatConfig() config = ChatConfig()
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
# Redis cache key prefix for chat sessions # Redis cache key prefix for chat sessions
CHAT_SESSION_CACHE_PREFIX = "chat:session:" CHAT_SESSION_CACHE_PREFIX = "chat:session:"
@@ -43,7 +52,28 @@ def _get_session_cache_key(session_id: str) -> str:
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}" return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
# ===================== Chat data models ===================== # # Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
@@ -55,19 +85,6 @@ class ChatMessage(BaseModel):
tool_calls: list[dict] | None = None tool_calls: list[dict] | None = None
function_call: dict | None = None function_call: dict | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
"""Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
return ChatMessage(
role=prisma_message.role,
content=prisma_message.content,
name=prisma_message.name,
tool_call_id=prisma_message.toolCallId,
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
)
class Usage(BaseModel): class Usage(BaseModel):
prompt_tokens: int prompt_tokens: int
@@ -75,10 +92,11 @@ class Usage(BaseModel):
total_tokens: int total_tokens: int
class ChatSessionInfo(BaseModel): class ChatSession(BaseModel):
session_id: str session_id: str
user_id: str user_id: str
title: str | None = None title: str | None = None
messages: list[ChatMessage]
usage: list[Usage] usage: list[Usage]
credentials: dict[str, dict] = {} # Map of provider -> credential metadata credentials: dict[str, dict] = {} # Map of provider -> credential metadata
started_at: datetime started_at: datetime
@@ -86,9 +104,40 @@ class ChatSessionInfo(BaseModel):
successful_agent_runs: dict[str, int] = {} successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {} successful_agent_schedules: dict[str, int] = {}
@classmethod @staticmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self: def new(user_id: str) -> "ChatSession":
"""Convert Prisma ChatSession to Pydantic ChatSession.""" return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
@staticmethod
def from_db(
prisma_session: PrismaChatSession,
prisma_messages: list[PrismaChatMessage] | None = None,
) -> "ChatSession":
"""Convert Prisma models to Pydantic ChatSession."""
messages = []
if prisma_messages:
for msg in prisma_messages:
messages.append(
ChatMessage(
role=msg.role,
content=msg.content,
name=msg.name,
tool_call_id=msg.toolCallId,
refusal=msg.refusal,
tool_calls=_parse_json_field(msg.toolCalls),
function_call=_parse_json_field(msg.functionCall),
)
)
# Parse JSON fields from Prisma # Parse JSON fields from Prisma
credentials = _parse_json_field(prisma_session.credentials, default={}) credentials = _parse_json_field(prisma_session.credentials, default={})
successful_agent_runs = _parse_json_field( successful_agent_runs = _parse_json_field(
@@ -110,10 +159,11 @@ class ChatSessionInfo(BaseModel):
) )
) )
return cls( return ChatSession(
session_id=prisma_session.id, session_id=prisma_session.id,
user_id=prisma_session.userId, user_id=prisma_session.userId,
title=prisma_session.title, title=prisma_session.title,
messages=messages,
usage=usage, usage=usage,
credentials=credentials, credentials=credentials,
started_at=prisma_session.createdAt, started_at=prisma_session.createdAt,
@@ -122,56 +172,6 @@ class ChatSessionInfo(BaseModel):
successful_agent_schedules=successful_agent_schedules, successful_agent_schedules=successful_agent_schedules,
) )
class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
@classmethod
def new(cls, user_id: str) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
"""Convert Prisma ChatSession to Pydantic ChatSession."""
if prisma_session.Messages is None:
raise ValueError(
f"Prisma session {prisma_session.id} is missing Messages relation"
)
return cls(
**ChatSessionInfo.from_db(prisma_session).model_dump(),
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
)
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
"""Attach a tool_call to the current turn's assistant message.
Searches backwards for the most recent assistant message (stopping at
any user message boundary). If found, appends the tool_call to it.
Otherwise creates a new assistant message with the tool_call.
"""
for msg in reversed(self.messages):
if msg.role == "user":
break
if msg.role == "assistant":
if not msg.tool_calls:
msg.tool_calls = []
msg.tool_calls.append(tool_call)
return
self.messages.append(
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
)
def to_openai_messages(self) -> list[ChatCompletionMessageParam]: def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
messages = [] messages = []
for message in self.messages: for message in self.messages:
@@ -258,72 +258,43 @@ class ChatSession(ChatSessionInfo):
name=message.name or "", name=message.name or "",
) )
) )
return self._merge_consecutive_assistant_messages(messages) return messages
@staticmethod
def _merge_consecutive_assistant_messages(
messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
"""Merge consecutive assistant messages into single messages.
Long-running tool flows can create split assistant messages: one with
text content and another with tool_calls. Anthropic's API requires
tool_result blocks to reference a tool_use in the immediately preceding
assistant message, so these splits cause 400 errors via OpenRouter.
"""
if len(messages) < 2:
return messages
result: list[ChatCompletionMessageParam] = [messages[0]]
for msg in messages[1:]:
prev = result[-1]
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
result.append(msg)
continue
prev = cast(ChatCompletionAssistantMessageParam, prev)
curr = cast(ChatCompletionAssistantMessageParam, msg)
curr_content = curr.get("content") or ""
if curr_content:
prev_content = prev.get("content") or ""
prev["content"] = (
f"{prev_content}\n{curr_content}" if prev_content else curr_content
)
curr_tool_calls = curr.get("tool_calls")
if curr_tool_calls:
prev_tool_calls = prev.get("tool_calls")
prev["tool_calls"] = (
list(prev_tool_calls) + list(curr_tool_calls)
if prev_tool_calls
else list(curr_tool_calls)
)
return result
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any: async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Parse a JSON field that may be stored as string or already parsed.""" """Get a chat session from Redis cache."""
if value is None: redis_key = _get_session_cache_key(session_id)
return default async_redis = await get_redis_async()
if isinstance(value, str): raw_session: bytes | None = await async_redis.get(redis_key)
return json.loads(value)
return value if raw_session is None:
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
# ================ Chat cache + DB operations ================ # async def _cache_session(session: ChatSession) -> None:
"""Cache a chat session in Redis."""
# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
# connected directly.
async def cache_chat_session(session: ChatSession) -> None:
"""Cache a chat session in Redis (without persisting to the database)."""
redis_key = _get_session_cache_key(session.session_id) redis_key = _get_session_cache_key(session.session_id)
async_redis = await get_redis_async() async_redis = await get_redis_async()
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json()) await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
async def cache_chat_session(session: ChatSession) -> None:
"""Cache a chat session without persisting to the database."""
await _cache_session(session)
async def invalidate_session_cache(session_id: str) -> None: async def invalidate_session_cache(session_id: str) -> None:
"""Invalidate a chat session from Redis cache. """Invalidate a chat session from Redis cache.
@@ -339,6 +310,80 @@ async def invalidate_session_cache(session_id: str) -> None:
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}") logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)
if not prisma_session:
return None
messages = prisma_session.Messages
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_db(prisma_session, messages)
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
# Check if session exists in DB
existing = await chat_db.get_chat_session(session.session_id)
if not existing:
# Create new session
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await chat_db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def get_chat_session( async def get_chat_session(
session_id: str, session_id: str,
user_id: str | None = None, user_id: str | None = None,
@@ -370,7 +415,7 @@ async def get_chat_session(
logger.warning(f"Unexpected cache error for session {session_id}: {e}") logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database # Fall back to database
logger.debug(f"Session {session_id} not in cache, checking database") logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id) session = await _get_session_from_db(session_id)
if session is None: if session is None:
@@ -386,7 +431,7 @@ async def get_chat_session(
# Cache the session from DB # Cache the session from DB
try: try:
await cache_chat_session(session) await _cache_session(session)
logger.info(f"Cached session {session_id} from database") logger.info(f"Cached session {session_id} from database")
except Exception as e: except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}") logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -394,45 +439,9 @@ async def get_chat_session(
return session return session
async def _get_session_from_cache(session_id: str) -> ChatSession | None: async def upsert_chat_session(
"""Get a chat session from Redis cache.""" session: ChatSession,
redis_key = _get_session_cache_key(session_id) ) -> ChatSession:
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
session = await chat_db().get_chat_session(session_id)
if not session:
return None
logger.info(
f"Loaded session {session_id} from DB: "
f"has_messages={bool(session.messages)}, "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
async def upsert_chat_session(session: ChatSession) -> ChatSession:
"""Update a chat session in both cache and database. """Update a chat session in both cache and database.
Uses session-level locking to prevent race conditions when concurrent Uses session-level locking to prevent race conditions when concurrent
@@ -450,7 +459,7 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
async with lock: async with lock:
# Get existing message count from DB for incremental saves # Get existing message count from DB for incremental saves
existing_message_count = await chat_db().get_chat_session_message_count( existing_message_count = await chat_db.get_chat_session_message_count(
session.session_id session.session_id
) )
@@ -467,7 +476,7 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
# Save to cache (best-effort, even if DB failed) # Save to cache (best-effort, even if DB failed)
try: try:
await cache_chat_session(session) await _cache_session(session)
except Exception as e: except Exception as e:
# If DB succeeded but cache failed, raise cache error # If DB succeeded but cache failed, raise cache error
if db_error is None: if db_error is None:
@@ -488,99 +497,6 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
return session return session
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
db = chat_db()
# Check if session exists in DB
existing = await db.get_chat_session(session.session_id)
if not existing:
# Create new session
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
async with lock:
session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db().get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
raise DatabaseError(
f"Failed to persist message to session {session_id}"
) from e
try:
await cache_chat_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
return session
async def create_chat_session(user_id: str) -> ChatSession: async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it. """Create a new chat session and persist it.
@@ -593,7 +509,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
# Create in database first - fail fast if this fails # Create in database first - fail fast if this fails
try: try:
await chat_db().create_chat_session( await chat_db.create_chat_session(
session_id=session.session_id, session_id=session.session_id,
user_id=user_id, user_id=user_id,
) )
@@ -605,7 +521,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
# Cache the session (best-effort optimization, DB is source of truth) # Cache the session (best-effort optimization, DB is source of truth)
try: try:
await cache_chat_session(session) await _cache_session(session)
except Exception as e: except Exception as e:
logger.warning(f"Failed to cache new session {session.session_id}: {e}") logger.warning(f"Failed to cache new session {session.session_id}: {e}")
@@ -616,16 +532,20 @@ async def get_user_sessions(
user_id: str, user_id: str,
limit: int = 50, limit: int = 50,
offset: int = 0, offset: int = 0,
) -> tuple[list[ChatSessionInfo], int]: ) -> tuple[list[ChatSession], int]:
"""Get chat sessions for a user from the database with total count. """Get chat sessions for a user from the database with total count.
Returns: Returns:
A tuple of (sessions, total_count) where total_count is the overall A tuple of (sessions, total_count) where total_count is the overall
number of sessions for the user (not just the current page). number of sessions for the user (not just the current page).
""" """
db = chat_db() prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
sessions = await db.get_user_chat_sessions(user_id, limit, offset) total_count = await chat_db.get_user_session_count(user_id)
total_count = await db.get_user_session_count(user_id)
sessions = []
for prisma_session in prisma_sessions:
# Convert without messages for listing (lighter weight)
sessions.append(ChatSession.from_db(prisma_session, None))
return sessions, total_count return sessions, total_count
@@ -643,7 +563,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
""" """
# Delete from database first (with optional user_id validation) # Delete from database first (with optional user_id validation)
# This confirms ownership before invalidating cache # This confirms ownership before invalidating cache
deleted = await chat_db().delete_chat_session(session_id, user_id) deleted = await chat_db.delete_chat_session(session_id, user_id)
if not deleted: if not deleted:
return False return False
@@ -678,52 +598,20 @@ async def update_session_title(session_id: str, title: str) -> bool:
True if updated successfully, False otherwise. True if updated successfully, False otherwise.
""" """
try: try:
result = await chat_db().update_chat_session(session_id=session_id, title=title) result = await chat_db.update_chat_session(session_id=session_id, title=title)
if result is None: if result is None:
logger.warning(f"Session {session_id} not found for title update") logger.warning(f"Session {session_id} not found for title update")
return False return False
# Update title in cache if it exists (instead of invalidating). # Invalidate cache so next fetch gets updated title
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
try: try:
cached = await _get_session_from_cache(session_id) redis_key = _get_session_cache_key(session_id)
if cached: async_redis = await get_redis_async()
cached.title = title await async_redis.delete(redis_key)
await cache_chat_session(cached)
except Exception as e: except Exception as e:
# Not critical - title will be correct on next full cache refresh logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to update title for session {session_id}: {e}") logger.error(f"Failed to update title for session {session_id}: {e}")
return False return False
# ==================== Chat session locks ==================== #
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
This was originally added to solve the specific problem of race conditions between
the session title thread and the conversation thread, which always occurs on the
same instance as we prevent rapid request sends on the frontend.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock

View File

@@ -0,0 +1,119 @@
import pytest
from .model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
upsert_chat_session,
)
messages = [
ChatMessage(content="Hello, how are you?", role="user"),
ChatMessage(
content="I'm fine, thank you!",
role="assistant",
tool_calls=[
{
"id": "t123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "New York"}',
},
}
],
),
ChatMessage(
content="I'm using the tool to get the weather",
role="tool",
tool_call_id="t123",
),
]
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
s = ChatSession.new(user_id="abc123")
s.messages = messages
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
serialized = s.model_dump_json()
s2 = ChatSession.model_validate_json(serialized)
assert s2.model_dump() == s.model_dump()
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s2 = await get_chat_session(
session_id=s.session_id,
user_id=s.user_id,
)
assert s2 == s
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s2 = await get_chat_session(s.session_id, "different_user_id")
assert s2 is None
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_db_storage(setup_test_user, test_user_id):
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
s = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
# Load from DB (cache was cleared)
s2 = await get_chat_session(
session_id=s.session_id,
user_id=s.user_id,
)
assert s2 is not None, "Session not found after loading from DB"
assert len(s2.messages) == len(
s.messages
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
# Verify all roles are present
roles = [m.role for m in s2.messages]
assert "user" in roles, f"User message missing. Roles found: {roles}"
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
# Verify message content
for orig, loaded in zip(s.messages, s2.messages):
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
assert (
orig.content == loaded.content
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
if orig.tool_calls:
assert (
loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls)

View File

@@ -10,8 +10,6 @@ from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
class ResponseType(str, Enum): class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol.""" """Types of streaming responses following AI SDK protocol."""
@@ -20,10 +18,6 @@ class ResponseType(str, Enum):
START = "start" START = "start"
FINISH = "finish" FINISH = "finish"
# Step lifecycle (one LLM API call within a message)
START_STEP = "start-step"
FINISH_STEP = "finish-step"
# Text streaming # Text streaming
TEXT_START = "text-start" TEXT_START = "text-start"
TEXT_DELTA = "text-delta" TEXT_DELTA = "text-delta"
@@ -63,16 +57,6 @@ class StreamStart(StreamBaseResponse):
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream", description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
) )
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
}
return f"data: {json.dumps(data)}\n\n"
class StreamFinish(StreamBaseResponse): class StreamFinish(StreamBaseResponse):
"""End of message/stream.""" """End of message/stream."""
@@ -80,26 +64,6 @@ class StreamFinish(StreamBaseResponse):
type: ResponseType = ResponseType.FINISH type: ResponseType = ResponseType.FINISH
class StreamStartStep(StreamBaseResponse):
"""Start of a step (one LLM API call within a message).
The AI SDK uses this to add a step-start boundary to message.parts,
enabling visual separation between multiple LLM calls in a single message.
"""
type: ResponseType = ResponseType.START_STEP
class StreamFinishStep(StreamBaseResponse):
"""End of a step (one LLM API call within a message).
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
so the next LLM call in a tool-call continuation starts with clean state.
"""
type: ResponseType = ResponseType.FINISH_STEP
# ========== Text Streaming ========== # ========== Text Streaming ==========
@@ -153,7 +117,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
toolCallId: str = Field(..., description="Tool call ID this responds to") toolCallId: str = Field(..., description="Tool call ID this responds to")
output: str | dict[str, Any] = Field(..., description="Tool execution output") output: str | dict[str, Any] = Field(..., description="Tool execution output")
# Keep these for internal backend use # Additional fields for internal use (not part of AI SDK spec but useful)
toolName: str | None = Field( toolName: str | None = Field(
default=None, description="Name of the tool that was executed" default=None, description="Name of the tool that was executed"
) )
@@ -161,17 +125,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded" default=True, description="Whether the tool execution succeeded"
) )
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,
"output": self.output,
}
return f"data: {json.dumps(data)}\n\n"
# ========== Other ========== # ========== Other ==========
@@ -195,18 +148,6 @@ class StreamError(StreamBaseResponse):
default=None, description="Additional error details" default=None, description="Additional error details"
) )
def to_sse(self) -> str:
"""Convert to SSE format, only emitting fields required by AI SDK protocol.
The AI SDK uses z.strictObject({type, errorText}) which rejects
any extra fields like `code` or `details`.
"""
data = {
"type": self.type.value,
"errorText": self.errorText,
}
return f"data: {json_dumps(data)}\n\n"
class StreamHeartbeat(StreamBaseResponse): class StreamHeartbeat(StreamBaseResponse):
"""Heartbeat to keep SSE connection alive during long-running operations. """Heartbeat to keep SSE connection alive during long-running operations.

View File

@@ -1,60 +1,24 @@
"""Chat API routes for chat session management and streaming via SSE.""" """Chat API routes for chat session management and streaming via SSE."""
import asyncio
import logging import logging
import uuid as uuid_module import uuid as uuid_module
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Annotated from typing import Annotated
from autogpt_libs import auth from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.completion_handler import (
process_operation_failure,
process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_copilot_task
from backend.copilot.model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
delete_chat_session,
get_chat_session,
get_user_sessions,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
AgentPreviewResponse,
AgentSavedResponse,
AgentsFoundResponse,
BlockDetailsResponse,
BlockListResponse,
BlockOutputResponse,
ClarificationNeededResponse,
DocPageResponse,
DocSearchResultsResponse,
ErrorResponse,
ExecutionStartedResponse,
InputValidationErrorResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError from backend.util.exceptions import NotFoundError
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
config = ChatConfig() config = ChatConfig()
@@ -213,43 +177,6 @@ async def create_session(
) )
@router.delete(
"/sessions/{session_id}",
dependencies=[Security(auth.requires_user)],
status_code=204,
responses={404: {"description": "Session not found or access denied"}},
)
async def delete_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
"""
Delete a chat session.
Permanently removes a chat session and all its messages.
Only the owner can delete their sessions.
Args:
session_id: The session ID to delete.
user_id: The authenticated user's ID.
Returns:
204 No Content on success.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
deleted = await delete_chat_session(session_id, user_id)
if not deleted:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return Response(status_code=204)
@router.get( @router.get(
"/sessions/{session_id}", "/sessions/{session_id}",
) )
@@ -282,10 +209,6 @@ async def get_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session( active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id session_id, user_id
) )
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_task: if active_task:
# Filter out the in-progress assistant message from the session response. # Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE # The client will receive the complete assistant response through the SSE
@@ -343,54 +266,12 @@ async def stream_chat_post(
""" """
import asyncio import asyncio
import time
stream_start_time = time.perf_counter() session = await _validate_and_get_session(session_id, user_id)
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
}
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support # Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4()) task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4()) operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
task_create_start = time.perf_counter()
await stream_registry.create_task( await stream_registry.create_task(
task_id=task_id, task_id=task_id,
session_id=session_id, session_id=session_id,
@@ -399,45 +280,40 @@ async def stream_chat_post(
tool_name="chat", tool_name="chat",
operation_id=operation_id, operation_id=operation_id,
) )
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
}
},
)
await enqueue_copilot_task( # Background task that runs the AI generation independently of SSE connection
task_id=task_id, async def run_ai_generation():
session_id=session_id, try:
user_id=user_id, # Emit a start event with task_id for reconnection
operation_id=operation_id, start_chunk = StreamStart(messageId=task_id, taskId=task_id)
message=request.message, await stream_registry.publish_chunk(task_id, start_chunk)
is_user_message=request.is_user_message,
context=request.context,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000 async for chunk in chat_service.stream_chat_completion(
logger.info( session_id,
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms", request.message,
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}}, is_user_message=request.is_user_message,
) user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
logger.error(
f"Error in background AI generation for session {session_id}: {e}"
)
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
# SSE endpoint that subscribes to the task's stream # SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]: async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
subscriber_queue = None subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
try: try:
# Subscribe to the task stream (this replays existing messages + live updates) # Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task( subscriber_queue = await stream_registry.subscribe_to_task(
@@ -452,78 +328,24 @@ async def stream_chat_post(
return return
# Read from the subscriber queue and yield to SSE # Read from the subscriber queue and yield to SSE
logger.info(
"[TIMING] Starting to read from subscriber_queue",
extra={"json_fields": log_meta},
)
while True: while True:
try: try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0) chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
chunks_yielded += 1
if not first_chunk_yielded:
first_chunk_yielded = True
elapsed = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
f"type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"elapsed_ms": elapsed * 1000,
}
},
)
yield chunk.to_sse() yield chunk.to_sse()
# Check for finish signal # Check for finish signal
if isinstance(chunk, StreamFinish): if isinstance(chunk, StreamFinish):
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
f"n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"total_time_ms": total_time * 1000,
}
},
)
break break
except asyncio.TimeoutError: except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse() yield StreamHeartbeat().to_sse()
except GeneratorExit: except GeneratorExit:
logger.info(
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"reason": "client_disconnect",
}
},
)
pass # Client disconnected - background task continues pass # Client disconnected - background task continues
except Exception as e: except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000 logger.error(f"Error in SSE stream for task {task_id}: {e}")
logger.error(
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
extra={
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
# Surface error to frontend so it doesn't appear stuck
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
finally: finally:
# Unsubscribe when client disconnects or stream ends # Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None: if subscriber_queue is not None:
try: try:
await stream_registry.unsubscribe_from_task( await stream_registry.unsubscribe_from_task(
@@ -535,18 +357,6 @@ async def stream_chat_post(
exc_info=True, exc_info=True,
) )
# AI SDK protocol termination - always yield even if unsubscribe fails # AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time * 1000,
"chunks_yielded": chunks_yielded,
}
},
)
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse( return StreamingResponse(
@@ -564,90 +374,63 @@ async def stream_chat_post(
@router.get( @router.get(
"/sessions/{session_id}/stream", "/sessions/{session_id}/stream",
) )
async def resume_session_stream( async def stream_chat_get(
session_id: str, session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_user_id), user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
): ):
""" """
Resume an active stream for a session. Stream chat responses for a session (GET - legacy endpoint).
Called by the AI SDK's ``useChat(resume: true)`` on page load. Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
Checks for an active (in-progress) task on the session and either replays - Text fragments as they are generated
the full SSE stream or returns 204 No Content if nothing is running. - Tool call UI elements (if invoked)
- Tool execution results
Args: Args:
session_id: The chat session identifier. session_id: The chat session identifier to associate with the streamed messages.
message: The user's new message to process.
user_id: Optional authenticated user ID. user_id: Optional authenticated user ID.
is_user_message: Whether the message is a user message.
Returns: Returns:
StreamingResponse (SSE) when an active stream exists, StreamingResponse: SSE-formatted response chunks.
or 204 No Content when there is nothing to resume.
""" """
import asyncio session = await _validate_and_get_session(session_id, user_id)
active_task, _last_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_task:
return Response(status_code=204)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
user_id=user_id,
last_message_id="0-0", # Full replay so useChat rebuilds the message
)
if subscriber_queue is None:
return Response(status_code=204)
async def event_generator() -> AsyncGenerator[str, None]: async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0 chunk_count = 0
first_chunk_type: str | None = None first_chunk_type: str | None = None
try: async for chunk in chat_service.stream_chat_completion(
while True: session_id,
try: message,
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0) is_user_message=is_user_message,
if chunk_count < 3: user_id=user_id,
logger.info( session=session, # Pass pre-fetched session to avoid double-fetch
"Resume stream chunk", ):
extra={ if chunk_count < 3:
"session_id": session_id, logger.info(
"chunk_type": str(chunk.type), "Chat stream chunk",
}, extra={
) "session_id": session_id,
if not first_chunk_type: "chunk_type": str(chunk.type),
first_chunk_type = str(chunk.type) },
chunk_count += 1
yield chunk.to_sse()
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass
except Exception as e:
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
) )
except Exception as unsub_err: if not first_chunk_type:
logger.error( first_chunk_type = str(chunk.type)
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}", chunk_count += 1
exc_info=True, yield chunk.to_sse()
) logger.info(
logger.info( "Chat stream completed",
"Resume stream completed", extra={
extra={ "session_id": session_id,
"session_id": session_id, "chunk_count": chunk_count,
"n_chunks": chunk_count, "first_chunk_type": first_chunk_type,
"first_chunk_type": first_chunk_type, },
}, )
) # AI SDK protocol termination
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse( return StreamingResponse(
event_generator(), event_generator(),
@@ -655,8 +438,8 @@ async def resume_session_stream(
headers={ headers={
"Cache-Control": "no-cache", "Cache-Control": "no-cache",
"Connection": "keep-alive", "Connection": "keep-alive",
"X-Accel-Buffering": "no", "X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", "x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
}, },
) )
@@ -767,6 +550,8 @@ async def stream_task(
) )
async def event_generator() -> AsyncGenerator[str, None]: async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try: try:
while True: while True:
@@ -966,43 +751,3 @@ async def health_check() -> dict:
"service": "chat", "service": "chat",
"version": "0.1.0", "version": "0.1.0",
} }
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
ToolResponseUnion = (
AgentsFoundResponse
| NoResultsResponse
| AgentDetailsResponse
| SetupRequirementsResponse
| ExecutionStartedResponse
| NeedLoginResponse
| ErrorResponse
| InputValidationErrorResponse
| AgentOutputResponse
| UnderstandingUpdatedResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
| BlockListResponse
| BlockDetailsResponse
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)
@router.get(
"/schema/tool-responses",
response_model=ToolResponseUnion,
include_in_schema=True,
summary="[Dummy] Tool response type export for codegen",
description="This endpoint is not meant to be called. It exists solely to "
"expose tool response models in the OpenAPI schema for frontend codegen.",
)
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
"""Never called at runtime. Exists only so Orval generates TS types."""
raise HTTPException(status_code=501, detail="Schema-only endpoint")

View File

@@ -27,18 +27,20 @@ from openai.types.chat import (
ChatCompletionToolParam, ChatCompletionToolParam,
) )
from backend.data.db_accessors import chat_db, understanding_db
from backend.data.redis_client import get_redis_async from backend.data.redis_client import get_redis_async
from backend.data.understanding import format_understanding_for_prompt from backend.data.understanding import (
format_understanding_for_prompt,
get_business_understanding,
)
from backend.util.exceptions import NotFoundError from backend.util.exceptions import NotFoundError
from backend.util.settings import AppEnvironment, Settings from backend.util.settings import AppEnvironment, Settings
from . import db as chat_db
from . import stream_registry from . import stream_registry
from .config import ChatConfig from .config import ChatConfig
from .model import ( from .model import (
ChatMessage, ChatMessage,
ChatSession, ChatSession,
ChatSessionInfo,
Usage, Usage,
cache_chat_session, cache_chat_session,
get_chat_session, get_chat_session,
@@ -50,10 +52,8 @@ from .response_model import (
StreamBaseResponse, StreamBaseResponse,
StreamError, StreamError,
StreamFinish, StreamFinish,
StreamFinishStep,
StreamHeartbeat, StreamHeartbeat,
StreamStart, StreamStart,
StreamStartStep,
StreamTextDelta, StreamTextDelta,
StreamTextEnd, StreamTextEnd,
StreamTextStart, StreamTextStart,
@@ -243,16 +243,12 @@ async def _get_system_prompt_template(context: str) -> str:
return DEFAULT_SYSTEM_PROMPT.format(users_information=context) return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
async def _build_system_prompt( async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
user_id: str | None, has_conversation_history: bool = False
) -> tuple[str, Any]:
"""Build the full system prompt including business understanding if available. """Build the full system prompt including business understanding if available.
Args: Args:
user_id: The user ID for fetching business understanding. user_id: The user ID for fetching business understanding
has_conversation_history: Whether there's existing conversation history. If "default" and this is the user's first session, will use "onboarding" instead.
If True, we don't tell the model to greet/introduce (since they're
already in a conversation).
Returns: Returns:
Tuple of (compiled prompt string, business understanding object) Tuple of (compiled prompt string, business understanding object)
@@ -261,15 +257,13 @@ async def _build_system_prompt(
understanding = None understanding = None
if user_id: if user_id:
try: try:
understanding = await understanding_db().get_business_understanding(user_id) understanding = await get_business_understanding(user_id)
except Exception as e: except Exception as e:
logger.warning(f"Failed to fetch business understanding: {e}") logger.warning(f"Failed to fetch business understanding: {e}")
understanding = None understanding = None
if understanding: if understanding:
context = format_understanding_for_prompt(understanding) context = format_understanding_for_prompt(understanding)
elif has_conversation_history:
context = "No prior understanding saved yet. Continue the existing conversation naturally."
else: else:
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform" context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
@@ -337,7 +331,7 @@ async def _generate_session_title(
async def assign_user_to_session( async def assign_user_to_session(
session_id: str, session_id: str,
user_id: str, user_id: str,
) -> ChatSessionInfo: ) -> ChatSession:
""" """
Assign a user to a chat session. Assign a user to a chat session.
""" """
@@ -357,10 +351,6 @@ async def stream_chat_completion(
retry_count: int = 0, retry_count: int = 0,
session: ChatSession | None = None, session: ChatSession | None = None,
context: dict[str, str] | None = None, # {url: str, content: str} context: dict[str, str] | None = None, # {url: str, content: str}
_continuation_message_id: (
str | None
) = None, # Internal: reuse message ID for tool call continuations
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
) -> AsyncGenerator[StreamBaseResponse, None]: ) -> AsyncGenerator[StreamBaseResponse, None]:
"""Main entry point for streaming chat completions with database handling. """Main entry point for streaming chat completions with database handling.
@@ -378,47 +368,24 @@ async def stream_chat_completion(
Raises: Raises:
NotFoundError: If session_id is invalid NotFoundError: If session_id is invalid
ValueError: If max_context_messages is exceeded
""" """
completion_start = time.monotonic()
# Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info( logger.info(
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, " f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
extra={
"json_fields": {
**log_meta,
"message_len": len(message) if message else 0,
"is_user_message": is_user_message,
}
},
) )
# Only fetch from Redis if session not provided (initial call) # Only fetch from Redis if session not provided (initial call)
if session is None: if session is None:
fetch_start = time.monotonic()
session = await get_chat_session(session_id, user_id) session = await get_chat_session(session_id, user_id)
fetch_time = (time.monotonic() - fetch_start) * 1000
logger.info( logger.info(
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, " f"Fetched session from Redis: {session.session_id if session else 'None'}, "
f"n_messages={len(session.messages) if session else 0}", f"message_count={len(session.messages) if session else 0}"
extra={
"json_fields": {
**log_meta,
"duration_ms": fetch_time,
"n_messages": len(session.messages) if session else 0,
}
},
) )
else: else:
logger.info( logger.info(
f"[TIMING] Using provided session, messages={len(session.messages)}", f"Using provided session object: {session.session_id}, "
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}}, f"message_count={len(session.messages)}"
) )
if not session: if not session:
@@ -426,16 +393,12 @@ async def stream_chat_completion(
f"Session {session_id} not found. Please create a new session first." f"Session {session_id} not found. Please create a new session first."
) )
# Append the new message to the session if it's not already there if message:
new_message_role = "user" if is_user_message else "assistant" session.messages.append(
if message and ( ChatMessage(
len(session.messages) == 0 role="user" if is_user_message else "assistant", content=message
or not ( )
session.messages[-1].role == new_message_role
and session.messages[-1].content == message
) )
):
session.messages.append(ChatMessage(role=new_message_role, content=message))
logger.info( logger.info(
f"Appended message (role={'user' if is_user_message else 'assistant'}), " f"Appended message (role={'user' if is_user_message else 'assistant'}), "
f"new message_count={len(session.messages)}" f"new message_count={len(session.messages)}"
@@ -443,32 +406,23 @@ async def stream_chat_completion(
# Track user message in PostHog # Track user message in PostHog
if is_user_message: if is_user_message:
posthog_start = time.monotonic()
track_user_message( track_user_message(
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
message_length=len(message), message_length=len(message),
) )
posthog_time = (time.monotonic() - posthog_start) * 1000
logger.info(
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
)
upsert_start = time.monotonic()
session = await upsert_chat_session(session)
upsert_time = (time.monotonic() - upsert_start) * 1000
logger.info( logger.info(
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms", f"Upserting session: {session.session_id} with user id {session.user_id}, "
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}}, f"message_count={len(session.messages)}"
) )
session = await upsert_chat_session(session)
assert session, "Session not found" assert session, "Session not found"
# Generate title for new sessions on first user message (non-blocking) # Generate title for new sessions on first user message (non-blocking)
# Check: is_user_message, no title yet, and this is the first user message # Check: is_user_message, no title yet, and this is the first user message
user_messages = [m for m in session.messages if m.role == "user"] if is_user_message and message and not session.title:
first_user_msg = message or (user_messages[0].content if user_messages else None) user_messages = [m for m in session.messages if m.role == "user"]
if is_user_message and first_user_msg and not session.title:
if len(user_messages) == 1: if len(user_messages) == 1:
# First user message - generate title in background # First user message - generate title in background
import asyncio import asyncio
@@ -476,7 +430,7 @@ async def stream_chat_completion(
# Capture only the values we need (not the session object) to avoid # Capture only the values we need (not the session object) to avoid
# stale data issues when the main flow modifies the session # stale data issues when the main flow modifies the session
captured_session_id = session_id captured_session_id = session_id
captured_message = first_user_msg captured_message = message
captured_user_id = user_id captured_user_id = user_id
async def _update_title(): async def _update_title():
@@ -500,13 +454,7 @@ async def stream_chat_completion(
asyncio.create_task(_update_title()) asyncio.create_task(_update_title())
# Build system prompt with business understanding # Build system prompt with business understanding
prompt_start = time.monotonic()
system_prompt, understanding = await _build_system_prompt(user_id) system_prompt, understanding = await _build_system_prompt(user_id)
prompt_time = (time.monotonic() - prompt_start) * 1000
logger.info(
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
)
# Initialize variables for streaming # Initialize variables for streaming
assistant_response = ChatMessage( assistant_response = ChatMessage(
@@ -531,27 +479,13 @@ async def stream_chat_completion(
# Generate unique IDs for AI SDK protocol # Generate unique IDs for AI SDK protocol
import uuid as uuid_module import uuid as uuid_module
is_continuation = _continuation_message_id is not None message_id = str(uuid_module.uuid4())
message_id = _continuation_message_id or str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4()) text_block_id = str(uuid_module.uuid4())
# Only yield message start for the initial call, not for continuations. # Yield message start
setup_time = (time.monotonic() - completion_start) * 1000 yield StreamStart(messageId=message_id)
logger.info(
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
if not is_continuation:
yield StreamStart(messageId=message_id, taskId=_task_id)
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
yield StreamStartStep()
try: try:
logger.info(
"[TIMING] Calling _stream_chat_chunks",
extra={"json_fields": log_meta},
)
async for chunk in _stream_chat_chunks( async for chunk in _stream_chat_chunks(
session=session, session=session,
tools=tools, tools=tools,
@@ -651,10 +585,6 @@ async def stream_chat_completion(
) )
yield chunk yield chunk
elif isinstance(chunk, StreamFinish): elif isinstance(chunk, StreamFinish):
if has_done_tool_call:
# Tool calls happened — close the step but don't send message-level finish.
# The continuation will open a new step, and finish will come at the end.
yield StreamFinishStep()
if not has_done_tool_call: if not has_done_tool_call:
# Emit text-end before finish if we received text but haven't closed it # Emit text-end before finish if we received text but haven't closed it
if has_received_text and not text_streaming_ended: if has_received_text and not text_streaming_ended:
@@ -686,8 +616,6 @@ async def stream_chat_completion(
has_saved_assistant_message = True has_saved_assistant_message = True
has_yielded_end = True has_yielded_end = True
# Emit finish-step before finish (resets AI SDK text/reasoning state)
yield StreamFinishStep()
yield chunk yield chunk
elif isinstance(chunk, StreamError): elif isinstance(chunk, StreamError):
has_yielded_error = True has_yielded_error = True
@@ -737,10 +665,6 @@ async def stream_chat_completion(
logger.info( logger.info(
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}" f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
) )
# Close the current step before retrying so the recursive call's
# StreamStartStep doesn't produce unbalanced step events.
if not has_yielded_end:
yield StreamFinishStep()
should_retry = True should_retry = True
else: else:
# Non-retryable error or max retries exceeded # Non-retryable error or max retries exceeded
@@ -776,7 +700,6 @@ async def stream_chat_completion(
error_response = StreamError(errorText=error_message) error_response = StreamError(errorText=error_message)
yield error_response yield error_response
if not has_yielded_end: if not has_yielded_end:
yield StreamFinishStep()
yield StreamFinish() yield StreamFinish()
return return
@@ -791,8 +714,6 @@ async def stream_chat_completion(
retry_count=retry_count + 1, retry_count=retry_count + 1,
session=session, session=session,
context=context, context=context,
_continuation_message_id=message_id, # Reuse message ID since start was already sent
_task_id=_task_id,
): ):
yield chunk yield chunk
return # Exit after retry to avoid double-saving in finally block return # Exit after retry to avoid double-saving in finally block
@@ -808,13 +729,9 @@ async def stream_chat_completion(
# Build the messages list in the correct order # Build the messages list in the correct order
messages_to_save: list[ChatMessage] = [] messages_to_save: list[ChatMessage] = []
# Add assistant message with tool_calls if any. # Add assistant message with tool_calls if any
# Use extend (not assign) to preserve tool_calls already added by
# _yield_tool_call for long-running tools.
if accumulated_tool_calls: if accumulated_tool_calls:
if not assistant_response.tool_calls: assistant_response.tool_calls = accumulated_tool_calls
assistant_response.tool_calls = []
assistant_response.tool_calls.extend(accumulated_tool_calls)
logger.info( logger.info(
f"Added {len(accumulated_tool_calls)} tool calls to assistant message" f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
) )
@@ -866,8 +783,6 @@ async def stream_chat_completion(
session=session, # Pass session object to avoid Redis refetch session=session, # Pass session object to avoid Redis refetch
context=context, context=context,
tool_call_response=str(tool_response_messages), tool_call_response=str(tool_response_messages),
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
_task_id=_task_id,
): ):
yield chunk yield chunk
@@ -978,21 +893,9 @@ async def _stream_chat_chunks(
SSE formatted JSON response objects SSE formatted JSON response objects
""" """
import time as time_module
stream_chunks_start = time_module.perf_counter()
model = config.model model = config.model
# Build log metadata for structured logging logger.info("Starting pure chat stream")
log_meta = {"component": "ChatService", "session_id": session.session_id}
if session.user_id:
log_meta["user_id"] = session.user_id
logger.info(
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
f"user={session.user_id}, n_messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
)
messages = session.to_openai_messages() messages = session.to_openai_messages()
if system_prompt: if system_prompt:
@@ -1003,18 +906,12 @@ async def _stream_chat_chunks(
messages = [system_message] + messages messages = [system_message] + messages
# Apply context window management # Apply context window management
context_start = time_module.perf_counter()
context_result = await _manage_context_window( context_result = await _manage_context_window(
messages=messages, messages=messages,
model=model, model=model,
api_key=config.api_key, api_key=config.api_key,
base_url=config.base_url, base_url=config.base_url,
) )
context_time = (time_module.perf_counter() - context_start) * 1000
logger.info(
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
)
if context_result.error: if context_result.error:
if "System prompt dropped" in context_result.error: if "System prompt dropped" in context_result.error:
@@ -1049,19 +946,9 @@ async def _stream_chat_chunks(
while retry_count <= MAX_RETRIES: while retry_count <= MAX_RETRIES:
try: try:
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
retry_info = (
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
)
logger.info( logger.info(
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}", f"Creating OpenAI chat completion stream..."
extra={ f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"retry_count": retry_count,
}
},
) )
# Build extra_body for OpenRouter tracing and PostHog analytics # Build extra_body for OpenRouter tracing and PostHog analytics
@@ -1078,11 +965,6 @@ async def _stream_chat_chunks(
:128 :128
] # OpenRouter limit ] # OpenRouter limit
# Enable adaptive thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in model.lower():
extra_body["reasoning"] = {"enabled": True}
api_call_start = time_module.perf_counter()
stream = await client.chat.completions.create( stream = await client.chat.completions.create(
model=model, model=model,
messages=cast(list[ChatCompletionMessageParam], messages), messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1092,11 +974,6 @@ async def _stream_chat_chunks(
stream_options=ChatCompletionStreamOptionsParam(include_usage=True), stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
extra_body=extra_body, extra_body=extra_body,
) )
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
logger.info(
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
)
# Variables to accumulate tool calls # Variables to accumulate tool calls
tool_calls: list[dict[str, Any]] = [] tool_calls: list[dict[str, Any]] = []
@@ -1107,13 +984,10 @@ async def _stream_chat_chunks(
# Track if we've started the text block # Track if we've started the text block
text_started = False text_started = False
first_content_chunk = True
chunk_count = 0
# Process the stream # Process the stream
chunk: ChatCompletionChunk chunk: ChatCompletionChunk
async for chunk in stream: async for chunk in stream:
chunk_count += 1
if chunk.usage: if chunk.usage:
yield StreamUsage( yield StreamUsage(
promptTokens=chunk.usage.prompt_tokens, promptTokens=chunk.usage.prompt_tokens,
@@ -1136,23 +1010,6 @@ async def _stream_chat_chunks(
if not text_started and text_block_id: if not text_started and text_block_id:
yield StreamTextStart(id=text_block_id) yield StreamTextStart(id=text_block_id)
text_started = True text_started = True
# Log timing for first content chunk
if first_content_chunk:
first_content_chunk = False
ttfc = (
time_module.perf_counter() - api_call_start
) * 1000
logger.info(
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
f"(since API call), n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"time_to_first_chunk_ms": ttfc,
"n_chunks": chunk_count,
}
},
)
# Stream the text delta # Stream the text delta
text_response = StreamTextDelta( text_response = StreamTextDelta(
id=text_block_id or "", id=text_block_id or "",
@@ -1209,21 +1066,7 @@ async def _stream_chat_chunks(
toolName=tool_calls[idx]["function"]["name"], toolName=tool_calls[idx]["function"]["name"],
) )
emitted_start_for_idx.add(idx) emitted_start_for_idx.add(idx)
stream_duration = time_module.perf_counter() - api_call_start logger.info(f"Stream complete. Finish reason: {finish_reason}")
logger.info(
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
f"duration={stream_duration:.2f}s, "
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
extra={
"json_fields": {
**log_meta,
"stream_duration_ms": stream_duration * 1000,
"finish_reason": finish_reason,
"n_chunks": chunk_count,
"n_tool_calls": len(tool_calls),
}
},
)
# Yield all accumulated tool calls after the stream is complete # Yield all accumulated tool calls after the stream is complete
# This ensures all tool call arguments have been fully received # This ensures all tool call arguments have been fully received
@@ -1243,17 +1086,10 @@ async def _stream_chat_chunks(
# Re-raise to trigger retry logic in the parent function # Re-raise to trigger retry logic in the parent function
raise raise
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
yield StreamFinish() yield StreamFinish()
return return
except Exception as e: except Exception as e:
last_error = e last_error = e
if _is_retryable_error(e) and retry_count < MAX_RETRIES: if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1 retry_count += 1
# Calculate delay with exponential backoff # Calculate delay with exponential backoff
@@ -1269,27 +1105,12 @@ async def _stream_chat_chunks(
continue # Retry the stream continue # Retry the stream
else: else:
# Non-retryable error or max retries exceeded # Non-retryable error or max retries exceeded
_log_api_error( logger.error(
error=e, f"Error in stream (not retrying): {e!s}",
context="stream (not retrying)", exc_info=True,
session_id=session.session_id if session else None,
message_count=len(messages) if messages else None,
model=model,
retry_count=retry_count,
) )
error_code = None error_code = None
error_text = str(e) error_text = str(e)
error_details = _extract_api_error_details(e)
if error_details.get("response_body"):
body = error_details["response_body"]
if isinstance(body, dict):
err = body.get("error")
if isinstance(err, dict) and err.get("message"):
error_text = err["message"]
elif body.get("message"):
error_text = body["message"]
if _is_region_blocked_error(e): if _is_region_blocked_error(e):
error_code = "MODEL_NOT_AVAILABLE_REGION" error_code = "MODEL_NOT_AVAILABLE_REGION"
error_text = ( error_text = (
@@ -1306,13 +1127,9 @@ async def _stream_chat_chunks(
# If we exit the retry loop without returning, it means we exhausted retries # If we exit the retry loop without returning, it means we exhausted retries
if last_error: if last_error:
_log_api_error( logger.error(
error=last_error, f"Max retries ({MAX_RETRIES}) exceeded. Last error: {last_error!s}",
context=f"stream (max retries {MAX_RETRIES} exceeded)", exc_info=True,
session_id=session.session_id if session else None,
message_count=len(messages) if messages else None,
model=model,
retry_count=MAX_RETRIES,
) )
yield StreamError(errorText=f"Max retries exceeded: {last_error!s}") yield StreamError(errorText=f"Max retries exceeded: {last_error!s}")
yield StreamFinish() yield StreamFinish()
@@ -1436,9 +1253,13 @@ async def _yield_tool_call(
operation_id=operation_id, operation_id=operation_id,
) )
# Attach the tool_call to the current turn's assistant message # Save assistant message with tool_call FIRST (required by LLM)
# (or create one if this is a tool-only response with no text). assistant_message = ChatMessage(
session.add_tool_call_to_current_turn(tool_calls[yield_idx]) role="assistant",
content="",
tool_calls=[tool_calls[yield_idx]],
)
session.messages.append(assistant_message)
# Then save pending tool result # Then save pending tool result
pending_message = ChatMessage( pending_message = ChatMessage(
@@ -1744,7 +1565,6 @@ async def _execute_long_running_tool_with_streaming(
task_id, task_id,
StreamError(errorText=str(e)), StreamError(errorText=str(e)),
) )
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish()) await stream_registry.publish_chunk(task_id, StreamFinish())
await _update_pending_operation( await _update_pending_operation(
@@ -1772,7 +1592,7 @@ async def _update_pending_operation(
This is called by background tasks when long-running operations complete. This is called by background tasks when long-running operations complete.
""" """
# Update the message in database # Update the message in database
updated = await chat_db().update_tool_message_content( updated = await chat_db.update_tool_message_content(
session_id=session_id, session_id=session_id,
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
new_content=result, new_content=result,
@@ -1861,10 +1681,6 @@ async def _generate_llm_continuation(
if session_id: if session_id:
extra_body["session_id"] = session_id[:128] extra_body["session_id"] = session_id[:128]
# Enable adaptive thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in config.model.lower():
extra_body["reasoning"] = {"enabled": True}
retry_count = 0 retry_count = 0
last_error: Exception | None = None last_error: Exception | None = None
response = None response = None
@@ -1885,7 +1701,6 @@ async def _generate_llm_continuation(
break # Success, exit retry loop break # Success, exit retry loop
except Exception as e: except Exception as e:
last_error = e last_error = e
if _is_retryable_error(e) and retry_count < MAX_RETRIES: if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1 retry_count += 1
delay = min( delay = min(
@@ -1899,25 +1714,17 @@ async def _generate_llm_continuation(
await asyncio.sleep(delay) await asyncio.sleep(delay)
continue continue
else: else:
# Non-retryable error - log details and exit gracefully # Non-retryable error - log and exit gracefully
_log_api_error( logger.error(
error=e, f"Non-retryable error in LLM continuation: {e!s}",
context="LLM continuation (not retrying)", exc_info=True,
session_id=session_id,
message_count=len(messages) if messages else None,
model=config.model,
retry_count=retry_count,
) )
return return
if last_error: if last_error:
_log_api_error( logger.error(
error=last_error, f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. "
context=f"LLM continuation (max retries {MAX_RETRIES} exceeded)", f"Last error: {last_error!s}"
session_id=session_id,
message_count=len(messages) if messages else None,
model=config.model,
retry_count=MAX_RETRIES,
) )
return return
@@ -1957,91 +1764,6 @@ async def _generate_llm_continuation(
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
def _log_api_error(
error: Exception,
context: str,
session_id: str | None = None,
message_count: int | None = None,
model: str | None = None,
retry_count: int = 0,
) -> None:
"""Log detailed API error information for debugging."""
details = _extract_api_error_details(error)
details["context"] = context
details["session_id"] = session_id
details["message_count"] = message_count
details["model"] = model
details["retry_count"] = retry_count
if isinstance(error, RateLimitError):
logger.warning(f"Rate limit error in {context}: {details}", exc_info=error)
elif isinstance(error, APIConnectionError):
logger.warning(f"API connection error in {context}: {details}", exc_info=error)
elif isinstance(error, APIStatusError) and error.status_code >= 500:
logger.error(f"API server error (5xx) in {context}: {details}", exc_info=error)
else:
logger.error(f"API error in {context}: {details}", exc_info=error)
def _extract_api_error_details(error: Exception) -> dict[str, Any]:
"""Extract detailed information from OpenAI/OpenRouter API errors."""
error_msg = str(error)
details: dict[str, Any] = {
"error_type": type(error).__name__,
"error_message": error_msg[:500] + "..." if len(error_msg) > 500 else error_msg,
}
if hasattr(error, "code"):
details["code"] = getattr(error, "code", None)
if hasattr(error, "param"):
details["param"] = getattr(error, "param", None)
if isinstance(error, APIStatusError):
details["status_code"] = error.status_code
details["request_id"] = getattr(error, "request_id", None)
if hasattr(error, "body") and error.body:
details["response_body"] = _sanitize_error_body(error.body)
if hasattr(error, "response") and error.response:
headers = error.response.headers
details["openrouter_provider"] = headers.get("x-openrouter-provider")
details["openrouter_model"] = headers.get("x-openrouter-model")
details["retry_after"] = headers.get("retry-after")
details["rate_limit_remaining"] = headers.get("x-ratelimit-remaining")
return details
def _sanitize_error_body(
body: Any, max_length: int = 2000
) -> dict[str, Any] | str | None:
"""Extract only safe fields from error response body to avoid logging sensitive data."""
if not isinstance(body, dict):
# Non-dict bodies (e.g., HTML error pages) - return truncated string
if body is not None:
body_str = str(body)
if len(body_str) > max_length:
return body_str[:max_length] + "...[truncated]"
return body_str
return None
safe_fields = ("message", "type", "code", "param", "error")
sanitized: dict[str, Any] = {}
for field in safe_fields:
if field in body:
value = body[field]
if field == "error" and isinstance(value, dict):
sanitized[field] = _sanitize_error_body(value, max_length)
elif isinstance(value, str) and len(value) > max_length:
sanitized[field] = value[:max_length] + "...[truncated]"
else:
sanitized[field] = value
return sanitized if sanitized else None
async def _generate_llm_continuation_with_streaming( async def _generate_llm_continuation_with_streaming(
session_id: str, session_id: str,
user_id: str | None, user_id: str | None,
@@ -2089,10 +1811,6 @@ async def _generate_llm_continuation_with_streaming(
if session_id: if session_id:
extra_body["session_id"] = session_id[:128] extra_body["session_id"] = session_id[:128]
# Enable adaptive thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in config.model.lower():
extra_body["reasoning"] = {"enabled": True}
# Make streaming LLM call (no tools - just text response) # Make streaming LLM call (no tools - just text response)
from typing import cast from typing import cast
@@ -2104,7 +1822,6 @@ async def _generate_llm_continuation_with_streaming(
# Publish start event # Publish start event
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id)) await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamStartStep())
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id)) await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
# Stream the response # Stream the response
@@ -2128,7 +1845,6 @@ async def _generate_llm_continuation_with_streaming(
# Publish end events # Publish end events
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id)) await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
await stream_registry.publish_chunk(task_id, StreamFinishStep())
if assistant_content: if assistant_content:
# Reload session from DB to avoid race condition with user messages # Reload session from DB to avoid race condition with user messages
@@ -2170,5 +1886,4 @@ async def _generate_llm_continuation_with_streaming(
task_id, task_id,
StreamError(errorText=f"Failed to generate response: {e}"), StreamError(errorText=f"Failed to generate response: {e}"),
) )
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish()) await stream_registry.publish_chunk(task_id, StreamFinish())

View File

@@ -0,0 +1,82 @@
import logging
from os import getenv
import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import (
StreamError,
StreamFinish,
StreamTextDelta,
StreamToolOutputAvailable,
)
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
has_errors = False
has_ended = False
assistant_message = ""
async for chunk in chat_service.stream_chat_completion(
session.session_id, "Hello, how are you?", user_id=session.user_id
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
if isinstance(chunk, StreamFinish):
has_ended = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert assistant_message, "Assistant message is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
has_errors = False
has_ended = False
had_tool_calls = False
async for chunk in chat_service.stream_chat_completion(
session.session_id,
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
user_id=session.user_id,
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamFinish):
has_ended = True
if isinstance(chunk, StreamToolOutputAvailable):
had_tool_calls = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await get_chat_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"

View File

@@ -104,24 +104,6 @@ async def create_task(
Returns: Returns:
The created ActiveTask instance (metadata only) The created ActiveTask instance (metadata only)
""" """
import time
start_time = time.perf_counter()
# Build log metadata for structured logging
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"session_id": session_id,
}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
task = ActiveTask( task = ActiveTask(
task_id=task_id, task_id=task_id,
session_id=session_id, session_id=session_id,
@@ -132,18 +114,10 @@ async def create_task(
) )
# Store metadata in Redis # Store metadata in Redis
redis_start = time.perf_counter()
redis = await get_redis_async() redis = await get_redis_async()
redis_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
)
meta_key = _get_task_meta_key(task_id) meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id) op_key = _get_operation_mapping_key(operation_id)
hset_start = time.perf_counter()
await redis.hset( # type: ignore[misc] await redis.hset( # type: ignore[misc]
meta_key, meta_key,
mapping={ mapping={
@@ -157,22 +131,12 @@ async def create_task(
"created_at": task.created_at.isoformat(), "created_at": task.created_at.isoformat(),
}, },
) )
hset_time = (time.perf_counter() - hset_start) * 1000
logger.info(
f"[TIMING] redis.hset took {hset_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
)
await redis.expire(meta_key, config.stream_ttl) await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups # Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl) await redis.set(op_key, task_id, ex=config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000 logger.debug(f"Created task {task_id} for session {session_id}")
logger.info(
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
return task return task
@@ -192,60 +156,26 @@ async def publish_chunk(
Returns: Returns:
The Redis Stream message ID The Redis Stream message ID
""" """
import time
start_time = time.perf_counter()
chunk_type = type(chunk).__name__
chunk_json = chunk.model_dump_json() chunk_json = chunk.model_dump_json()
message_id = "0-0" message_id = "0-0"
# Build log metadata
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"chunk_type": chunk_type,
}
try: try:
redis = await get_redis_async() redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id) stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery # Write to Redis Stream for persistence and real-time delivery
xadd_start = time.perf_counter()
raw_id = await redis.xadd( raw_id = await redis.xadd(
stream_key, stream_key,
{"data": chunk_json}, {"data": chunk_json},
maxlen=config.stream_max_length, maxlen=config.stream_max_length,
) )
xadd_time = (time.perf_counter() - xadd_start) * 1000
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode() message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL # Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl) await redis.expire(stream_key, config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
chunk_type
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
or total_time > 50
):
logger.info(
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"xadd_time_ms": xadd_time,
"message_id": message_id,
}
},
)
except Exception as e: except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000
logger.error( logger.error(
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}", f"Failed to publish chunk for task {task_id}: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
exc_info=True, exc_info=True,
) )
@@ -270,61 +200,24 @@ async def subscribe_to_task(
An asyncio Queue that will receive stream chunks, or None if task not found An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access or user doesn't have access
""" """
import time
start_time = time.perf_counter()
# Build log metadata
log_meta = {"component": "StreamRegistry", "task_id": task_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
)
redis_start = time.perf_counter()
redis = await get_redis_async() redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id) meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
if not meta: if not meta:
elapsed = (time.perf_counter() - start_time) * 1000 logger.debug(f"Task {task_id} not found in Redis")
logger.info(
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"reason": "task_not_found",
}
},
)
return None return None
# Note: Redis client uses decode_responses=True, so keys are strings # Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "") task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None task_user_id = meta.get("user_id", "") or None
log_meta["session_id"] = meta.get("session_id", "")
# Validate ownership - if task has an owner, requester must match # Validate ownership - if task has an owner, requester must match
if task_user_id: if task_user_id:
if user_id != task_user_id: if user_id != task_user_id:
logger.warning( logger.warning(
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}", f"User {user_id} denied access to task {task_id} "
extra={ f"owned by {task_user_id}"
"json_fields": {
**log_meta,
"task_owner": task_user_id,
"reason": "access_denied",
}
},
) )
return None return None
@@ -332,19 +225,7 @@ async def subscribe_to_task(
stream_key = _get_task_stream_key(task_id) stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream # Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
extra={
"json_fields": {
**log_meta,
"duration_ms": xread_time,
"task_status": task_status,
}
},
)
replayed_count = 0 replayed_count = 0
replay_last_id = last_message_id replay_last_id = last_message_id
@@ -363,48 +244,19 @@ async def subscribe_to_task(
except Exception as e: except Exception as e:
logger.warning(f"Failed to replay message: {e}") logger.warning(f"Failed to replay message: {e}")
logger.info( logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
extra={
"json_fields": {
**log_meta,
"n_messages_replayed": replayed_count,
"replay_last_id": replay_last_id,
}
},
)
# Step 2: If task is still running, start stream listener for live updates # Step 2: If task is still running, start stream listener for live updates
if task_status == "running": if task_status == "running":
logger.info(
"[TIMING] Task still running, starting _stream_listener",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
listener_task = asyncio.create_task( listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta) _stream_listener(task_id, subscriber_queue, replay_last_id)
) )
# Track listener task for cleanup on unsubscribe # Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task) _listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else: else:
# Task is completed/failed - add finish marker # Task is completed/failed - add finish marker
logger.info(
f"[TIMING] Task already {task_status}, adding StreamFinish",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
await subscriber_queue.put(StreamFinish()) await subscriber_queue.put(StreamFinish())
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
f"n_messages_replayed={replayed_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"n_messages_replayed": replayed_count,
}
},
)
return subscriber_queue return subscriber_queue
@@ -412,7 +264,6 @@ async def _stream_listener(
task_id: str, task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse], subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str, last_replayed_id: str,
log_meta: dict | None = None,
) -> None: ) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD. """Listen to Redis Stream for new messages using blocking XREAD.
@@ -423,27 +274,10 @@ async def _stream_listener(
task_id: Task ID to listen for task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here) last_replayed_id: Last message ID from replay (continue from here)
log_meta: Structured logging metadata
""" """
import time
start_time = time.perf_counter()
# Use provided log_meta or build minimal one
if log_meta is None:
log_meta = {"component": "StreamRegistry", "task_id": task_id}
logger.info(
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
)
queue_id = id(subscriber_queue) queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints # Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id last_delivered_id = last_replayed_id
messages_delivered = 0
first_message_time = None
xread_count = 0
try: try:
redis = await get_redis_async() redis = await get_redis_async()
@@ -453,39 +287,9 @@ async def _stream_listener(
while True: while True:
# Block for up to 30 seconds waiting for new messages # Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running # This allows periodic checking if task is still running
xread_start = time.perf_counter()
xread_count += 1
messages = await redis.xread( messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100 {stream_key: current_id}, block=30000, count=100
) )
xread_time = (time.perf_counter() - xread_start) * 1000
if messages:
msg_count = sum(len(msgs) for _, msgs in messages)
logger.info(
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"n_messages": msg_count,
"duration_ms": xread_time,
}
},
)
elif xread_time > 1000:
# Only log timeouts (30s blocking)
logger.info(
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"duration_ms": xread_time,
"reason": "timeout",
}
},
)
if not messages: if not messages:
# Timeout - check if task is still running # Timeout - check if task is still running
@@ -522,30 +326,10 @@ async def _stream_listener(
) )
# Update last delivered ID on successful delivery # Update last delivered ID on successful delivery
last_delivered_id = current_id last_delivered_id = current_id
messages_delivered += 1
if first_message_time is None:
first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000
logger.info(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"chunk_type": type(chunk).__name__,
}
},
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning( logger.warning(
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s", f"Subscriber queue full for task {task_id}, "
extra={ f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
"json_fields": {
**log_meta,
"timeout_s": QUEUE_PUT_TIMEOUT,
"reason": "queue_full",
}
},
) )
# Send overflow error with recovery info # Send overflow error with recovery info
try: try:
@@ -567,44 +351,15 @@ async def _stream_listener(
# Stop listening on finish # Stop listening on finish
if isinstance(chunk, StreamFinish): if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
}
},
)
return return
except Exception as e: except Exception as e:
logger.warning( logger.warning(f"Error processing stream message: {e}")
f"Error processing stream message: {e}",
extra={"json_fields": {**log_meta, "error": str(e)}},
)
except asyncio.CancelledError: except asyncio.CancelledError:
elapsed = (time.perf_counter() - start_time) * 1000 logger.debug(f"Stream listener cancelled for task {task_id}")
logger.info(
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"messages_delivered": messages_delivered,
"reason": "cancelled",
}
},
)
raise # Re-raise to propagate cancellation raise # Re-raise to propagate cancellation
except Exception as e: except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000 logger.error(f"Stream listener error for task {task_id}: {e}")
logger.error(
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
)
# On error, send finish to unblock subscriber # On error, send finish to unblock subscriber
try: try:
await asyncio.wait_for( await asyncio.wait_for(
@@ -613,24 +368,10 @@ async def _stream_listener(
) )
except (asyncio.TimeoutError, asyncio.QueueFull): except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning( logger.warning(
"Could not deliver finish event after error", f"Could not deliver finish event for task {task_id} after error"
extra={"json_fields": log_meta},
) )
finally: finally:
# Clean up listener task mapping on exit # Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
"xread_count": xread_count,
}
},
)
_listener_tasks.pop(queue_id, None) _listener_tasks.pop(queue_id, None)
@@ -814,28 +555,6 @@ async def get_active_task_for_session(
if task_user_id and user_id != task_user_id: if task_user_id and user_id != task_user_id:
continue continue
# Auto-expire stale tasks that exceeded stream_timeout
created_at_str = meta.get("created_at", "")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (
datetime.now(timezone.utc) - created_at
).total_seconds()
if age_seconds > config.stream_timeout:
logger.warning(
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
)
await mark_task_completed(task_id, "failed")
continue
except (ValueError, TypeError):
pass
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)
# Get the last message ID from Redis Stream # Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id) stream_key = _get_task_stream_key(task_id)
last_id = "0-0" last_id = "0-0"
@@ -879,10 +598,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
ResponseType, ResponseType,
StreamError, StreamError,
StreamFinish, StreamFinish,
StreamFinishStep,
StreamHeartbeat, StreamHeartbeat,
StreamStart, StreamStart,
StreamStartStep,
StreamTextDelta, StreamTextDelta,
StreamTextEnd, StreamTextEnd,
StreamTextStart, StreamTextStart,
@@ -896,8 +613,6 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
type_to_class: dict[str, type[StreamBaseResponse]] = { type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart, ResponseType.START.value: StreamStart,
ResponseType.FINISH.value: StreamFinish, ResponseType.FINISH.value: StreamFinish,
ResponseType.START_STEP.value: StreamStartStep,
ResponseType.FINISH_STEP.value: StreamFinishStep,
ResponseType.TEXT_START.value: StreamTextStart, ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta, ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd, ResponseType.TEXT_END.value: StreamTextEnd,

View File

@@ -3,18 +3,15 @@ from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.copilot.tracking import track_tool_called from backend.api.features.chat.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool from .agent_output import AgentOutputTool
from .base import BaseTool from .base import BaseTool
from .bash_exec import BashExecTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool from .edit_agent import EditAgentTool
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
from .find_agent import FindAgentTool from .find_agent import FindAgentTool
from .find_block import FindBlockTool from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool from .find_library_agent import FindLibraryAgentTool
@@ -22,7 +19,6 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool from .run_agent import RunAgentTool
from .run_block import RunBlockTool from .run_block import RunBlockTool
from .search_docs import SearchDocsTool from .search_docs import SearchDocsTool
from .web_fetch import WebFetchTool
from .workspace_files import ( from .workspace_files import (
DeleteWorkspaceFileTool, DeleteWorkspaceFileTool,
ListWorkspaceFilesTool, ListWorkspaceFilesTool,
@@ -31,7 +27,7 @@ from .workspace_files import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from backend.copilot.response_model import StreamToolOutputAvailable from backend.api.features.chat.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -47,17 +43,8 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(), "run_agent": RunAgentTool(),
"run_block": RunBlockTool(), "run_block": RunBlockTool(),
"view_agent_output": AgentOutputTool(), "view_agent_output": AgentOutputTool(),
"check_operation_status": CheckOperationStatusTool(),
"search_docs": SearchDocsTool(), "search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(), "get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),
# Workspace tools for CoPilot file operations # Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(), "list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(), "read_workspace_file": ReadWorkspaceFileTool(),

View File

@@ -6,11 +6,11 @@ import pytest
from prisma.types import ProfileCreateInput from prisma.types import ProfileCreateInput
from pydantic import SecretStr from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.data.db import prisma from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials from backend.data.model import APIKeyCredentials

View File

@@ -3,9 +3,11 @@
import logging import logging
from typing import Any from typing import Any
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.data.db_accessors import understanding_db from backend.data.understanding import (
from backend.data.understanding import BusinessUnderstandingInput BusinessUnderstandingInput,
upsert_business_understanding,
)
from .base import BaseTool from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
@@ -97,9 +99,7 @@ and automations for the user's specific needs."""
] ]
# Upsert with merge # Upsert with merge
understanding = await understanding_db().upsert_business_understanding( understanding = await upsert_business_understanding(user_id, input_data)
user_id, input_data
)
# Build current understanding summary (filter out empty values) # Build current understanding summary (filter out empty values)
current_understanding = { current_understanding = {

View File

@@ -5,8 +5,9 @@ import re
import uuid import uuid
from typing import Any, NotRequired, TypedDict from typing import Any, NotRequired, TypedDict
from backend.data.db_accessors import graph_db, library_db, store_db from backend.api.features.library import db as library_db
from backend.data.graph import Graph, Link, Node from backend.api.features.store import db as store_db
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
from backend.util.exceptions import DatabaseError, NotFoundError from backend.util.exceptions import DatabaseError, NotFoundError
from .service import ( from .service import (
@@ -144,9 +145,8 @@ async def get_library_agent_by_id(
Returns: Returns:
LibraryAgentSummary if found, None otherwise LibraryAgentSummary if found, None otherwise
""" """
db = library_db()
try: try:
agent = await db.get_library_agent_by_graph_id(user_id, agent_id) agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent: if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}") logger.debug(f"Found library agent by graph_id: {agent.name}")
return LibraryAgentSummary( return LibraryAgentSummary(
@@ -163,7 +163,7 @@ async def get_library_agent_by_id(
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}") logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
try: try:
agent = await db.get_library_agent(agent_id, user_id) agent = await library_db.get_library_agent(agent_id, user_id)
if agent: if agent:
logger.debug(f"Found library agent by library_id: {agent.name}") logger.debug(f"Found library agent by library_id: {agent.name}")
return LibraryAgentSummary( return LibraryAgentSummary(
@@ -215,7 +215,7 @@ async def get_library_agents_for_generation(
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
""" """
try: try:
response = await library_db().list_library_agents( response = await library_db.list_library_agents(
user_id=user_id, user_id=user_id,
search_term=search_query, search_term=search_query,
page=1, page=1,
@@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation(
List of LibraryAgentSummary with full input/output schemas List of LibraryAgentSummary with full input/output schemas
""" """
try: try:
response = await store_db().get_store_agents( response = await store_db.get_store_agents(
search_query=search_query, search_query=search_query,
page=1, page=1,
page_size=max_results, page_size=max_results,
@@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation(
return [] return []
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs] graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
graphs = await graph_db().get_store_listed_graphs(graph_ids) graphs = await get_store_listed_graphs(*graph_ids)
results: list[LibraryAgentSummary] = [] results: list[LibraryAgentSummary] = []
for agent in agents_with_graphs: for agent in agents_with_graphs:
@@ -673,10 +673,9 @@ async def save_agent_to_library(
Tuple of (created Graph, LibraryAgent) Tuple of (created Graph, LibraryAgent)
""" """
graph = json_to_graph(agent_json) graph = json_to_graph(agent_json)
db = library_db()
if is_update: if is_update:
return await db.update_graph_in_library(graph, user_id) return await library_db.update_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id) return await library_db.create_graph_in_library(graph, user_id)
def graph_to_json(graph: Graph) -> dict[str, Any]: def graph_to_json(graph: Graph) -> dict[str, Any]:
@@ -736,14 +735,12 @@ async def get_agent_as_json(
Returns: Returns:
Agent as JSON dict or None if not found Agent as JSON dict or None if not found
""" """
db = graph_db() graph = await get_graph(agent_id, version=None, user_id=user_id)
graph = await db.get_graph(agent_id, version=None, user_id=user_id)
if not graph and user_id: if not graph and user_id:
try: try:
library_agent = await library_db().get_library_agent(agent_id, user_id) library_agent = await library_db.get_library_agent(agent_id, user_id)
graph = await db.get_graph( graph = await get_graph(
library_agent.graph_id, version=None, user_id=user_id library_agent.graph_id, version=None, user_id=user_id
) )
except NotFoundError: except NotFoundError:

View File

@@ -12,19 +12,8 @@ import httpx
from backend.util.settings import Settings from backend.util.settings import Settings
from .dummy import (
customize_template_dummy,
decompose_goal_dummy,
generate_agent_dummy,
generate_agent_patch_dummy,
get_blocks_dummy,
health_check_dummy,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_dummy_mode_warned = False
def _create_error_response( def _create_error_response(
error_message: str, error_message: str,
@@ -101,26 +90,10 @@ def _get_settings() -> Settings:
return _settings return _settings
def _is_dummy_mode() -> bool:
"""Check if dummy mode is enabled for testing."""
global _dummy_mode_warned
settings = _get_settings()
is_dummy = bool(settings.config.agentgenerator_use_dummy)
if is_dummy and not _dummy_mode_warned:
logger.warning(
"Agent Generator running in DUMMY MODE - returning mock responses. "
"Do not use in production!"
)
_dummy_mode_warned = True
return is_dummy
def is_external_service_configured() -> bool: def is_external_service_configured() -> bool:
"""Check if external Agent Generator service is configured (or dummy mode).""" """Check if external Agent Generator service is configured."""
settings = _get_settings() settings = _get_settings()
return bool(settings.config.agentgenerator_host) or bool( return bool(settings.config.agentgenerator_host)
settings.config.agentgenerator_use_dummy
)
def _get_base_url() -> str: def _get_base_url() -> str:
@@ -164,9 +137,6 @@ async def decompose_goal_external(
- {"type": "error", "error": "...", "error_type": "..."} on error - {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error Or None on unexpected error
""" """
if _is_dummy_mode():
return await decompose_goal_dummy(description, context, library_agents)
client = _get_client() client = _get_client()
if context: if context:
@@ -256,11 +226,6 @@ async def generate_agent_external(
Returns: Returns:
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
""" """
if _is_dummy_mode():
return await generate_agent_dummy(
instructions, library_agents, operation_id, task_id
)
client = _get_client() client = _get_client()
# Build request payload # Build request payload
@@ -332,11 +297,6 @@ async def generate_agent_patch_external(
Returns: Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
""" """
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents, operation_id, task_id
)
client = _get_client() client = _get_client()
# Build request payload # Build request payload
@@ -423,11 +383,6 @@ async def customize_template_external(
Returns: Returns:
Customized agent JSON, clarifying questions dict, or error dict on error Customized agent JSON, clarifying questions dict, or error dict on error
""" """
if _is_dummy_mode():
return await customize_template_dummy(
template_agent, modification_request, context
)
client = _get_client() client = _get_client()
request = modification_request request = modification_request
@@ -490,9 +445,6 @@ async def get_blocks_external() -> list[dict[str, Any]] | None:
Returns: Returns:
List of block info dicts or None on error List of block info dicts or None on error
""" """
if _is_dummy_mode():
return await get_blocks_dummy()
client = _get_client() client = _get_client()
try: try:
@@ -526,9 +478,6 @@ async def health_check() -> bool:
if not is_external_service_configured(): if not is_external_service_configured():
return False return False
if _is_dummy_mode():
return await health_check_dummy()
client = _get_client() client = _get_client()
try: try:

View File

@@ -7,9 +7,10 @@ from typing import Any
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent from backend.api.features.library.model import LibraryAgent
from backend.copilot.model import ChatSession from backend.data import execution as execution_db
from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool from .base import BaseTool
@@ -164,12 +165,10 @@ class AgentOutputTool(BaseTool):
Resolve agent from provided identifiers. Resolve agent from provided identifiers.
Returns (library_agent, error_message). Returns (library_agent, error_message).
""" """
lib_db = library_db()
# Priority 1: Exact library agent ID # Priority 1: Exact library agent ID
if library_agent_id: if library_agent_id:
try: try:
agent = await lib_db.get_library_agent(library_agent_id, user_id) agent = await library_db.get_library_agent(library_agent_id, user_id)
return agent, None return agent, None
except Exception as e: except Exception as e:
logger.warning(f"Failed to get library agent by ID: {e}") logger.warning(f"Failed to get library agent by ID: {e}")
@@ -183,7 +182,7 @@ class AgentOutputTool(BaseTool):
return None, f"Agent '{store_slug}' not found in marketplace" return None, f"Agent '{store_slug}' not found in marketplace"
# Find in user's library by graph_id # Find in user's library by graph_id
agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id) agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
if not agent: if not agent:
return ( return (
None, None,
@@ -195,7 +194,7 @@ class AgentOutputTool(BaseTool):
# Priority 3: Fuzzy name search in library # Priority 3: Fuzzy name search in library
if agent_name: if agent_name:
try: try:
response = await lib_db.list_library_agents( response = await library_db.list_library_agents(
user_id=user_id, user_id=user_id,
search_term=agent_name, search_term=agent_name,
page_size=5, page_size=5,
@@ -229,11 +228,9 @@ class AgentOutputTool(BaseTool):
Fetch execution(s) based on filters. Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message). Returns (single_execution, available_executions_meta, error_message).
""" """
exec_db = execution_db()
# If specific execution_id provided, fetch it directly # If specific execution_id provided, fetch it directly
if execution_id: if execution_id:
execution = await exec_db.get_graph_execution( execution = await execution_db.get_graph_execution(
user_id=user_id, user_id=user_id,
execution_id=execution_id, execution_id=execution_id,
include_node_executions=False, include_node_executions=False,
@@ -243,7 +240,7 @@ class AgentOutputTool(BaseTool):
return execution, [], None return execution, [], None
# Get completed executions with time filters # Get completed executions with time filters
executions = await exec_db.get_graph_executions( executions = await execution_db.get_graph_executions(
graph_id=graph_id, graph_id=graph_id,
user_id=user_id, user_id=user_id,
statuses=[ExecutionStatus.COMPLETED], statuses=[ExecutionStatus.COMPLETED],
@@ -257,7 +254,7 @@ class AgentOutputTool(BaseTool):
# If only one execution, fetch full details # If only one execution, fetch full details
if len(executions) == 1: if len(executions) == 1:
full_execution = await exec_db.get_graph_execution( full_execution = await execution_db.get_graph_execution(
user_id=user_id, user_id=user_id,
execution_id=executions[0].id, execution_id=executions[0].id,
include_node_executions=False, include_node_executions=False,
@@ -265,7 +262,7 @@ class AgentOutputTool(BaseTool):
return full_execution, [], None return full_execution, [], None
# Multiple executions - return latest with full details, plus list of available # Multiple executions - return latest with full details, plus list of available
full_execution = await exec_db.get_graph_execution( full_execution = await execution_db.get_graph_execution(
user_id=user_id, user_id=user_id,
execution_id=executions[0].id, execution_id=executions[0].id,
include_node_executions=False, include_node_executions=False,
@@ -383,7 +380,7 @@ class AgentOutputTool(BaseTool):
and not input_data.store_slug and not input_data.store_slug
): ):
# Fetch execution directly to get graph_id # Fetch execution directly to get graph_id
execution = await execution_db().get_graph_execution( execution = await execution_db.get_graph_execution(
user_id=user_id, user_id=user_id,
execution_id=input_data.execution_id, execution_id=input_data.execution_id,
include_node_executions=False, include_node_executions=False,
@@ -395,7 +392,7 @@ class AgentOutputTool(BaseTool):
) )
# Find library agent by graph_id # Find library agent by graph_id
agent = await library_db().get_library_agent_by_graph_id( agent = await library_db.get_library_agent_by_graph_id(
user_id, execution.graph_id user_id, execution.graph_id
) )
if not agent: if not agent:

View File

@@ -4,7 +4,8 @@ import logging
import re import re
from typing import Literal from typing import Literal
from backend.data.db_accessors import library_db, store_db from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.util.exceptions import DatabaseError, NotFoundError from backend.util.exceptions import DatabaseError, NotFoundError
from .models import ( from .models import (
@@ -44,10 +45,8 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
Returns: Returns:
AgentInfo if found, None otherwise AgentInfo if found, None otherwise
""" """
lib_db = library_db()
try: try:
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id) agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent: if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}") logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo( return AgentInfo(
@@ -72,7 +71,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
) )
try: try:
agent = await lib_db.get_library_agent(agent_id, user_id) agent = await library_db.get_library_agent(agent_id, user_id)
if agent: if agent:
logger.debug(f"Found library agent by library_id: {agent.name}") logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo( return AgentInfo(
@@ -134,7 +133,7 @@ async def search_agents(
try: try:
if source == "marketplace": if source == "marketplace":
logger.info(f"Searching marketplace for: {query}") logger.info(f"Searching marketplace for: {query}")
results = await store_db().get_store_agents(search_query=query, page_size=5) results = await store_db.get_store_agents(search_query=query, page_size=5)
for agent in results.agents: for agent in results.agents:
agents.append( agents.append(
AgentInfo( AgentInfo(
@@ -160,7 +159,7 @@ async def search_agents(
if not agents: if not agents:
logger.info(f"Searching user library for: {query}") logger.info(f"Searching user library for: {query}")
results = await library_db().list_library_agents( results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type] user_id=user_id, # type: ignore[arg-type]
search_term=query, search_term=query,
page_size=10, page_size=10,

View File

@@ -5,8 +5,8 @@ from typing import Any
from openai.types.chat import ChatCompletionToolParam from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable from backend.api.features.chat.response_model import StreamToolOutputAvailable
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase

View File

@@ -3,7 +3,7 @@
import logging import logging
from typing import Any from typing import Any
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_generator import ( from .agent_generator import (
AgentGeneratorNotConfiguredError, AgentGeneratorNotConfiguredError,

View File

@@ -3,9 +3,9 @@
import logging import logging
from typing import Any from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError from backend.api.features.store.exceptions import AgentNotFoundError
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db
from .agent_generator import ( from .agent_generator import (
AgentGeneratorNotConfiguredError, AgentGeneratorNotConfiguredError,
@@ -137,8 +137,6 @@ class CustomizeAgentTool(BaseTool):
creator_username, agent_slug = parts creator_username, agent_slug = parts
store_db = get_store_db()
# Fetch the marketplace agent details # Fetch the marketplace agent details
try: try:
agent_details = await store_db.get_store_agent_details( agent_details = await store_db.get_store_agent_details(

View File

@@ -3,7 +3,7 @@
import logging import logging
from typing import Any from typing import Any
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_generator import ( from .agent_generator import (
AgentGeneratorNotConfiguredError, AgentGeneratorNotConfiguredError,

View File

@@ -2,7 +2,7 @@
from typing import Any from typing import Any
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents from .agent_search import search_agents
from .base import BaseTool from .base import BaseTool

View File

@@ -3,43 +3,20 @@ from typing import Any
from prisma.enums import ContentType from prisma.enums import ContentType
from backend.blocks import get_block from backend.api.features.chat.model import ChatSession
from backend.blocks._base import BlockType from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.copilot.model import ChatSession from backend.api.features.chat.tools.models import (
from backend.data.db_accessors import search
from .base import BaseTool, ToolResponseBase
from .models import (
BlockInfoSummary, BlockInfoSummary,
BlockInputFieldInfo,
BlockListResponse, BlockListResponse,
ErrorResponse, ErrorResponse,
NoResultsResponse, NoResultsResponse,
) )
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import get_block
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_TARGET_RESULTS = 10
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
_OVERFETCH_PAGE_SIZE = 40
# Block types that only work within graphs and cannot run standalone in CoPilot.
COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
COPILOT_EXCLUDED_BLOCK_IDS = {
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
"3b191d9f-356f-482d-8238-ba04b6d18381",
}
class FindBlockTool(BaseTool): class FindBlockTool(BaseTool):
"""Tool for searching available blocks.""" """Tool for searching available blocks."""
@@ -55,8 +32,7 @@ class FindBlockTool(BaseTool):
"Blocks are reusable components that perform specific tasks like " "Blocks are reusable components that perform specific tasks like "
"sending emails, making API calls, processing text, etc. " "sending emails, making API calls, processing text, etc. "
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. " "IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
"The response includes each block's id, name, and description. " "The response includes each block's id, required_inputs, and input_schema."
"Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it."
) )
@property @property
@@ -108,11 +84,11 @@ class FindBlockTool(BaseTool):
try: try:
# Search for blocks using hybrid search # Search for blocks using hybrid search
results, total = await search().unified_hybrid_search( results, total = await unified_hybrid_search(
query=query, query=query,
content_types=[ContentType.BLOCK], content_types=[ContentType.BLOCK],
page=1, page=1,
page_size=_OVERFETCH_PAGE_SIZE, page_size=10,
) )
if not results: if not results:
@@ -125,44 +101,67 @@ class FindBlockTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
# Enrich results with block information # Enrich results with full block information
blocks: list[BlockInfoSummary] = [] blocks: list[BlockInfoSummary] = []
for result in results: for result in results:
block_id = result["content_id"] block_id = result["content_id"]
block = get_block(block_id) block = get_block(block_id)
# Skip disabled blocks # Skip disabled blocks
if not block or block.disabled: if block and not block.disabled:
continue # Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception:
pass
try:
output_schema = block.output_schema.jsonschema()
except Exception:
pass
# Skip blocks excluded from CoPilot (graph-only blocks) # Get categories from block instance
if ( categories = []
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES if hasattr(block, "categories") and block.categories:
or block.id in COPILOT_EXCLUDED_BLOCK_IDS categories = [cat.value for cat in block.categories]
):
continue
blocks.append( # Extract required inputs for easier use
BlockInfoSummary( required_inputs: list[BlockInputFieldInfo] = []
id=block_id, if input_schema:
name=block.name, properties = input_schema.get("properties", {})
description=block.description or "", required_fields = set(input_schema.get("required", []))
categories=[c.value for c in block.categories], # Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
)
) )
)
if len(blocks) >= _TARGET_RESULTS:
break
if blocks and len(blocks) < _TARGET_RESULTS:
logger.debug(
"find_block returned %d/%d results for query '%s' "
"(filtered %d excluded/disabled blocks)",
len(blocks),
_TARGET_RESULTS,
query,
len(results) - len(blocks),
)
if not blocks: if not blocks:
return NoResultsResponse( return NoResultsResponse(
@@ -176,7 +175,8 @@ class FindBlockTool(BaseTool):
return BlockListResponse( return BlockListResponse(
message=( message=(
f"Found {len(blocks)} block(s) matching '{query}'. " f"Found {len(blocks)} block(s) matching '{query}'. "
"To see a block's inputs/outputs and execute it, use run_block with the block's 'id' - providing no inputs." "To execute a block, use run_block with the block's 'id' field "
"and provide 'input_data' matching the block's input_schema."
), ),
blocks=blocks, blocks=blocks,
count=len(blocks), count=len(blocks),

View File

@@ -2,7 +2,7 @@
from typing import Any from typing import Any
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents from .agent_search import search_agents
from .base import BaseTool from .base import BaseTool

View File

@@ -4,10 +4,13 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from .base import BaseTool from backend.api.features.chat.tools.models import (
from .models import DocPageResponse, ErrorResponse, ToolResponseBase DocPageResponse,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -25,7 +25,6 @@ class ResponseType(str, Enum):
AGENT_SAVED = "agent_saved" AGENT_SAVED = "agent_saved"
CLARIFICATION_NEEDED = "clarification_needed" CLARIFICATION_NEEDED = "clarification_needed"
BLOCK_LIST = "block_list" BLOCK_LIST = "block_list"
BLOCK_DETAILS = "block_details"
BLOCK_OUTPUT = "block_output" BLOCK_OUTPUT = "block_output"
DOC_SEARCH_RESULTS = "doc_search_results" DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page" DOC_PAGE = "doc_page"
@@ -41,15 +40,6 @@ class ResponseType(str, Enum):
OPERATION_IN_PROGRESS = "operation_in_progress" OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation # Input validation
INPUT_VALIDATION_ERROR = "input_validation_error" INPUT_VALIDATION_ERROR = "input_validation_error"
# Web fetch
WEB_FETCH = "web_fetch"
# Code execution
BASH_EXEC = "bash_exec"
# Operation status check
OPERATION_STATUS = "operation_status"
# Feature request types
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
# Base response model # Base response model
@@ -345,17 +335,11 @@ class BlockInfoSummary(BaseModel):
name: str name: str
description: str description: str
categories: list[str] categories: list[str]
input_schema: dict[str, Any] = Field( input_schema: dict[str, Any]
default_factory=dict, output_schema: dict[str, Any]
description="Full JSON schema for block inputs",
)
output_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block outputs",
)
required_inputs: list[BlockInputFieldInfo] = Field( required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list, default_factory=list,
description="List of input fields for this block", description="List of required input fields for this block",
) )
@@ -368,29 +352,10 @@ class BlockListResponse(ToolResponseBase):
query: str query: str
usage_hint: str = Field( usage_hint: str = Field(
default="To execute a block, call run_block with block_id set to the block's " default="To execute a block, call run_block with block_id set to the block's "
"'id' field and input_data containing the fields listed in required_inputs." "'id' field and input_data containing the required fields from input_schema."
) )
class BlockDetails(BaseModel):
"""Detailed block information."""
id: str
name: str
description: str
inputs: dict[str, Any] = {}
outputs: dict[str, Any] = {}
credentials: list[CredentialsMetaInput] = []
class BlockDetailsResponse(ToolResponseBase):
"""Response for block details (first run_block attempt)."""
type: ResponseType = ResponseType.BLOCK_DETAILS
block: BlockDetails
user_authenticated: bool = False
class BlockOutputResponse(ToolResponseBase): class BlockOutputResponse(ToolResponseBase):
"""Response for run_block tool.""" """Response for run_block tool."""
@@ -456,55 +421,3 @@ class AsyncProcessingResponse(ToolResponseBase):
status: str = "accepted" # Must be "accepted" for detection status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None operation_id: str | None = None
task_id: str | None = None task_id: str | None = None
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""
type: ResponseType = ResponseType.WEB_FETCH
url: str
status_code: int
content_type: str
content: str
truncated: bool = False
class BashExecResponse(ToolResponseBase):
"""Response for bash_exec tool."""
type: ResponseType = ResponseType.BASH_EXEC
stdout: str
stderr: str
exit_code: int
timed_out: bool = False
# Feature request models
class FeatureRequestInfo(BaseModel):
"""Information about a feature request issue."""
id: str
identifier: str
title: str
description: str | None = None
class FeatureRequestSearchResponse(ToolResponseBase):
"""Response for search_feature_requests tool."""
type: ResponseType = ResponseType.FEATURE_REQUEST_SEARCH
results: list[FeatureRequestInfo]
count: int
query: str
class FeatureRequestCreatedResponse(ToolResponseBase):
"""Response for create_feature_request tool."""
type: ResponseType = ResponseType.FEATURE_REQUEST_CREATED
issue_id: str
issue_identifier: str
issue_title: str
issue_url: str
is_new_issue: bool # False if added to existing
customer_name: str

View File

@@ -5,12 +5,16 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from backend.copilot.config import ChatConfig from backend.api.features.chat.config import ChatConfig
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled from backend.api.features.chat.tracking import (
from backend.data.db_accessors import graph_db, library_db, user_db track_agent_run_success,
track_agent_scheduled,
)
from backend.api.features.library import db as library_db
from backend.data.graph import GraphModel from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
from backend.executor import utils as execution_utils from backend.executor import utils as execution_utils
from backend.util.clients import get_scheduler_client from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, NotFoundError from backend.util.exceptions import DatabaseError, NotFoundError
@@ -20,7 +24,6 @@ from backend.util.timezone_utils import (
) )
from .base import BaseTool from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import ( from .models import (
AgentDetails, AgentDetails,
AgentDetailsResponse, AgentDetailsResponse,
@@ -196,7 +199,7 @@ class RunAgentTool(BaseTool):
# Priority: library_agent_id if provided # Priority: library_agent_id if provided
if has_library_id: if has_library_id:
library_agent = await library_db().get_library_agent( library_agent = await library_db.get_library_agent(
params.library_agent_id, user_id params.library_agent_id, user_id
) )
if not library_agent: if not library_agent:
@@ -205,7 +208,9 @@ class RunAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
# Get the graph from the library agent # Get the graph from the library agent
graph = await graph_db().get_graph( from backend.data.graph import get_graph
graph = await get_graph(
library_agent.graph_id, library_agent.graph_id,
library_agent.graph_version, library_agent.graph_version,
user_id=user_id, user_id=user_id,
@@ -256,7 +261,7 @@ class RunAgentTool(BaseTool):
), ),
requirements={ requirements={
"credentials": requirements_creds_list, "credentials": requirements_creds_list,
"inputs": get_inputs_from_schema(graph.input_schema), "inputs": self._get_inputs_list(graph.input_schema),
"execution_modes": self._get_execution_modes(graph), "execution_modes": self._get_execution_modes(graph),
}, },
), ),
@@ -364,6 +369,22 @@ class RunAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema."""
inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]: def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph.""" """Get available execution modes for the graph."""
trigger_info = graph.trigger_setup_info trigger_info = graph.trigger_setup_info
@@ -377,7 +398,7 @@ class RunAgentTool(BaseTool):
suffix: str, suffix: str,
) -> str: ) -> str:
"""Build a message describing available inputs for an agent.""" """Build a message describing available inputs for an agent."""
inputs_list = get_inputs_from_schema(graph.input_schema) inputs_list = self._get_inputs_list(graph.input_schema)
required_names = [i["name"] for i in inputs_list if i["required"]] required_names = [i["name"] for i in inputs_list if i["required"]]
optional_names = [i["name"] for i in inputs_list if not i["required"]] optional_names = [i["name"] for i in inputs_list if not i["required"]]
@@ -516,7 +537,7 @@ class RunAgentTool(BaseTool):
library_agent = await get_or_create_library_agent(graph, user_id) library_agent = await get_or_create_library_agent(graph, user_id)
# Get user timezone # Get user timezone
user = await user_db().get_user_by_id(user_id) user = await get_user_by_id(user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone) user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)
# Create schedule # Create schedule

View File

@@ -7,33 +7,24 @@ from typing import Any
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from backend.blocks import get_block from backend.api.features.chat.model import ChatSession
from backend.blocks._base import AnyBlockSchema from backend.data.block import get_block
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError from backend.util.exceptions import BlockError
from .base import BaseTool from .base import BaseTool
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
from .helpers import get_inputs_from_schema
from .models import ( from .models import (
BlockDetails,
BlockDetailsResponse,
BlockOutputResponse, BlockOutputResponse,
ErrorResponse, ErrorResponse,
InputValidationErrorResponse,
SetupInfo, SetupInfo,
SetupRequirementsResponse, SetupRequirementsResponse,
ToolResponseBase, ToolResponseBase,
UserReadiness, UserReadiness,
) )
from .utils import ( from .utils import build_missing_credentials_from_field_info
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -51,8 +42,8 @@ class RunBlockTool(BaseTool):
"Execute a specific block with the provided input data. " "Execute a specific block with the provided input data. "
"IMPORTANT: You MUST call find_block first to get the block's 'id' - " "IMPORTANT: You MUST call find_block first to get the block's 'id' - "
"do NOT guess or make up block IDs. " "do NOT guess or make up block IDs. "
"On first attempt (without input_data), returns detailed schema showing " "Use the 'id' from find_block results and provide input_data "
"required inputs and outputs. Then call again with proper input_data to execute." "matching the block's required_inputs."
) )
@property @property
@@ -67,19 +58,11 @@ class RunBlockTool(BaseTool):
"NEVER guess this - always get it from find_block first." "NEVER guess this - always get it from find_block first."
), ),
}, },
"block_name": {
"type": "string",
"description": (
"The block's human-readable name from find_block results. "
"Used for display purposes in the UI."
),
},
"input_data": { "input_data": {
"type": "object", "type": "object",
"description": ( "description": (
"Input values for the block. " "Input values for the block. Use the 'required_inputs' field "
"First call with empty {} to see the block's schema, " "from find_block to see what fields are needed."
"then call again with proper values to execute."
), ),
}, },
}, },
@@ -90,6 +73,91 @@ class RunBlockTool(BaseTool):
def requires_auth(self) -> bool: def requires_auth(self) -> bool:
return True return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute( async def _execute(
self, self,
user_id: str | None, user_id: str | None,
@@ -144,54 +212,13 @@ class RunBlockTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
# Check if block is excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
return ErrorResponse(
message=(
f"Block '{block.name}' cannot be run directly in CoPilot. "
"This block is designed for use within graphs only."
),
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager() creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = ( matched_credentials, missing_credentials = await self._check_block_credentials(
await self._resolve_block_credentials(user_id, block, input_data) user_id, block, input_data
) )
# Get block schemas for details/validation
try:
input_schema: dict[str, Any] = block.input_schema.jsonschema()
except Exception as e:
logger.warning(
"Failed to generate input schema for block %s: %s",
block_id,
e,
)
return ErrorResponse(
message=f"Block '{block.name}' has an invalid input schema",
error=str(e),
session_id=session_id,
)
try:
output_schema: dict[str, Any] = block.output_schema.jsonschema()
except Exception as e:
logger.warning(
"Failed to generate output schema for block %s: %s",
block_id,
e,
)
return ErrorResponse(
message=f"Block '{block.name}' has an invalid output schema",
error=str(e),
session_id=session_id,
)
if missing_credentials: if missing_credentials:
# Return setup requirements response with missing credentials # Return setup requirements response with missing credentials
credentials_fields_info = block.input_schema.get_credentials_fields_info() credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -224,56 +251,9 @@ class RunBlockTool(BaseTool):
graph_version=None, graph_version=None,
) )
# Check if this is a first attempt (required inputs missing)
# Return block details so user can see what inputs are needed
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
required_keys = set(input_schema.get("required", []))
required_non_credential_keys = required_keys - credentials_fields
provided_input_keys = set(input_data.keys()) - credentials_fields
# Check for unknown input fields
valid_fields = (
set(input_schema.get("properties", {}).keys()) - credentials_fields
)
unrecognized_fields = provided_input_keys - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Block was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=input_schema,
)
# Show details when not all required non-credential inputs are provided
if not (required_non_credential_keys <= provided_input_keys):
# Get credentials info for the response
credentials_meta = []
for field_name, cred_meta in matched_credentials.items():
credentials_meta.append(cred_meta)
return BlockDetailsResponse(
message=(
f"Block '{block.name}' details. "
"Provide input_data matching the inputs schema to execute the block."
),
session_id=session_id,
block=BlockDetails(
id=block_id,
name=block.name,
description=block.description or "",
inputs=input_schema,
outputs=output_schema,
credentials=credentials_meta,
),
user_authenticated=True,
)
try: try:
# Get or create user's workspace for CoPilot file operations # Get or create user's workspace for CoPilot file operations
workspace = await workspace_db().get_or_create_workspace(user_id) workspace = await get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context # Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run # Each chat session is treated as its own agent with one continuous run
@@ -365,75 +345,29 @@ class RunBlockTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
async def _resolve_block_credentials( def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema.""" """Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema() schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys()) credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
def _resolve_discriminated_credentials( for field_name, field_schema in properties.items():
self, # Skip credential fields
block: AnyBlockSchema, if field_name in credentials_fields:
input_data: dict[str, Any], continue
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
resolved: dict[str, CredentialsFieldInfo] = {} inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
for field_name, field_info in credentials_fields_info.items(): return inputs_list
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved

View File

@@ -5,17 +5,16 @@ from typing import Any
from prisma.enums import ContentType from prisma.enums import ContentType
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.data.db_accessors import search from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
from .base import BaseTool
from .models import (
DocSearchResult, DocSearchResult,
DocSearchResultsResponse, DocSearchResultsResponse,
ErrorResponse, ErrorResponse,
NoResultsResponse, NoResultsResponse,
ToolResponseBase, ToolResponseBase,
) )
from backend.api.features.store.hybrid_search import unified_hybrid_search
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -118,7 +117,7 @@ class SearchDocsTool(BaseTool):
try: try:
# Search using hybrid search for DOCUMENTATION content type only # Search using hybrid search for DOCUMENTATION content type only
results, total = await search().unified_hybrid_search( results, total = await unified_hybrid_search(
query=query, query=query,
content_types=[ContentType.DOCUMENTATION], content_types=[ContentType.DOCUMENTATION],
page=1, page=1,

View File

@@ -3,18 +3,17 @@
import logging import logging
from typing import Any from typing import Any
from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model from backend.api.features.library import model as library_model
from backend.data.db_accessors import library_db, store_db from backend.api.features.store import db as store_db
from backend.data.graph import GraphModel from backend.data.graph import GraphModel
from backend.data.model import ( from backend.data.model import (
Credentials,
CredentialsFieldInfo, CredentialsFieldInfo,
CredentialsMetaInput, CredentialsMetaInput,
HostScopedCredentials, HostScopedCredentials,
OAuth2Credentials, OAuth2Credentials,
) )
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.exceptions import NotFoundError from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -38,14 +37,13 @@ async def fetch_graph_from_store_slug(
Raises: Raises:
DatabaseError: If there's a database error during lookup. DatabaseError: If there's a database error during lookup.
""" """
sdb = store_db()
try: try:
store_agent = await sdb.get_store_agent_details(username, agent_name) store_agent = await store_db.get_store_agent_details(username, agent_name)
except NotFoundError: except NotFoundError:
return None, None return None, None
# Get the graph from store listing version # Get the graph from store listing version
graph = await sdb.get_available_graph( graph = await store_db.get_available_graph(
store_agent.store_listing_version_id, hide_nodes=False store_agent.store_listing_version_id, hide_nodes=False
) )
return graph, store_agent return graph, store_agent
@@ -119,7 +117,7 @@ def build_missing_credentials_from_graph(
preserving all supported credential types for each field. preserving all supported credential types for each field.
""" """
matched_keys = set(matched_credentials.keys()) if matched_credentials else set() matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
aggregated_fields = graph.aggregate_credentials_inputs() aggregated_fields = graph.regular_credentials_inputs
return { return {
field_key: _serialize_missing_credential(field_key, field_info) field_key: _serialize_missing_credential(field_key, field_info)
@@ -210,13 +208,13 @@ async def get_or_create_library_agent(
Returns: Returns:
LibraryAgent instance LibraryAgent instance
""" """
existing = await library_db().get_library_agent_by_graph_id( existing = await library_db.get_library_agent_by_graph_id(
graph_id=graph.id, user_id=user_id graph_id=graph.id, user_id=user_id
) )
if existing: if existing:
return existing return existing
library_agents = await library_db().create_library_agent( library_agents = await library_db.create_library_agent(
graph=graph, graph=graph,
user_id=user_id, user_id=user_id,
create_library_agents_for_sub_graphs=False, create_library_agents_for_sub_graphs=False,
@@ -225,99 +223,6 @@ async def get_or_create_library_agent(
return library_agents[0] return library_agents[0]
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, scopes, and host."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if cred.type == "oauth2" and not _credential_has_required_scopes(
cred, field_info
):
continue
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
async def match_user_credentials_to_graph( async def match_user_credentials_to_graph(
user_id: str, user_id: str,
graph: GraphModel, graph: GraphModel,
@@ -339,7 +244,7 @@ async def match_user_credentials_to_graph(
missing_creds: list[str] = [] missing_creds: list[str] = []
# Get aggregated credentials requirements from the graph # Get aggregated credentials requirements from the graph
aggregated_creds = graph.aggregate_credentials_inputs() aggregated_creds = graph.regular_credentials_inputs
logger.debug( logger.debug(
f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required" f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
) )
@@ -360,7 +265,7 @@ async def match_user_credentials_to_graph(
_, _,
_, _,
) in aggregated_creds.items(): ) in aggregated_creds.items():
# Find first matching credential by provider, type, scopes, and host/URL # Find first matching credential by provider, type, and scopes
matching_cred = next( matching_cred = next(
( (
cred cred
@@ -375,10 +280,6 @@ async def match_user_credentials_to_graph(
cred.type != "host_scoped" cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements) or _credential_is_for_host(cred, credential_requirements)
) )
and (
cred.provider != ProviderName.MCP
or _credential_is_for_mcp_server(cred, credential_requirements)
)
), ),
None, None,
) )
@@ -430,6 +331,8 @@ def _credential_has_required_scopes(
# If no scopes are required, any credential matches # If no scopes are required, any credential matches
if not requirements.required_scopes: if not requirements.required_scopes:
return True return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes) return set(credential.scopes).issuperset(requirements.required_scopes)
@@ -449,22 +352,6 @@ def _credential_is_for_host(
return credential.matches_url(list(requirements.discriminator_values)[0]) return credential.matches_url(list(requirements.discriminator_values)[0])
def _credential_is_for_mcp_server(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if an MCP OAuth credential matches the required server URL."""
if not requirements.discriminator_values:
return True
server_url = (
credential.metadata.get("mcp_server_url")
if isinstance(credential, OAuth2Credentials)
else None
)
return server_url in requirements.discriminator_values if server_url else False
async def check_user_has_required_credentials( async def check_user_has_required_credentials(
user_id: str, user_id: str,
required_credentials: list[CredentialsMetaInput], required_credentials: list[CredentialsMetaInput],

View File

@@ -0,0 +1,78 @@
"""Tests for chat tools utility functions."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.model import CredentialsFieldInfo
def _make_regular_field() -> CredentialsFieldInfo:
return CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["github"],
"credentials_types": ["api_key"],
"is_auto_credential": False,
},
by_alias=True,
)
def test_build_missing_credentials_excludes_auto_creds():
"""
build_missing_credentials_from_graph() should use regular_credentials_inputs
and thus exclude auto_credentials from the "missing" set.
"""
from backend.api.features.chat.tools.utils import (
build_missing_credentials_from_graph,
)
regular_field = _make_regular_field()
mock_graph = MagicMock()
# regular_credentials_inputs should only return the non-auto field
mock_graph.regular_credentials_inputs = {
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
}
result = build_missing_credentials_from_graph(mock_graph, matched_credentials=None)
# Should include the regular credential
assert "github_api_key" in result
# Should NOT include the auto_credential (not in regular_credentials_inputs)
assert "google_oauth2" not in result
@pytest.mark.asyncio
async def test_match_user_credentials_excludes_auto_creds():
"""
match_user_credentials_to_graph() should use regular_credentials_inputs
and thus exclude auto_credentials from matching.
"""
from backend.api.features.chat.tools.utils import match_user_credentials_to_graph
regular_field = _make_regular_field()
mock_graph = MagicMock()
mock_graph.id = "test-graph"
# regular_credentials_inputs returns only non-auto fields
mock_graph.regular_credentials_inputs = {
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
}
# Mock the credentials manager to return no credentials
with patch(
"backend.api.features.chat.tools.utils.IntegrationCredentialsManager"
) as MockCredsMgr:
mock_store = AsyncMock()
mock_store.get_all_creds.return_value = []
MockCredsMgr.return_value.store = mock_store
matched, missing = await match_user_credentials_to_graph(
user_id="test-user", graph=mock_graph
)
# No credentials available, so github should be missing
assert len(matched) == 0
assert len(missing) == 1
assert "github_api_key" in missing[0]

View File

@@ -6,8 +6,8 @@ from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from backend.copilot.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.data.db_accessors import workspace_db from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager from backend.util.workspace import WorkspaceManager
@@ -88,9 +88,7 @@ class ListWorkspaceFilesTool(BaseTool):
@property @property
def description(self) -> str: def description(self) -> str:
return ( return (
"List files in the user's persistent workspace (cloud storage). " "List files in the user's workspace. "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read/Glob tools instead. "
"Returns file names, paths, sizes, and metadata. " "Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix." "Optionally filter by path prefix."
) )
@@ -148,7 +146,7 @@ class ListWorkspaceFilesTool(BaseTool):
include_all_sessions: bool = kwargs.get("include_all_sessions", False) include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try: try:
workspace = await workspace_db().get_or_create_workspace(user_id) workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access # Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id) manager = WorkspaceManager(user_id, workspace.id, session_id)
@@ -167,8 +165,8 @@ class ListWorkspaceFilesTool(BaseTool):
file_id=f.id, file_id=f.id,
name=f.name, name=f.name,
path=f.path, path=f.path,
mime_type=f.mime_type, mime_type=f.mimeType,
size_bytes=f.size_bytes, size_bytes=f.sizeBytes,
) )
for f in files for f in files
] ]
@@ -206,9 +204,7 @@ class ReadWorkspaceFileTool(BaseTool):
@property @property
def description(self) -> str: def description(self) -> str:
return ( return (
"Read a file from the user's persistent workspace (cloud storage). " "Read a file from the user's workspace. "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read tool instead. "
"Specify either file_id or path to identify the file. " "Specify either file_id or path to identify the file. "
"For small text files, returns content directly. " "For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. " "For large or binary files, returns metadata and a download URL. "
@@ -284,7 +280,7 @@ class ReadWorkspaceFileTool(BaseTool):
) )
try: try:
workspace = await workspace_db().get_or_create_workspace(user_id) workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access # Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id) manager = WorkspaceManager(user_id, workspace.id, session_id)
@@ -309,8 +305,8 @@ class ReadWorkspaceFileTool(BaseTool):
target_file_id = file_info.id target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL # Decide whether to return inline content or metadata+URL
is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mime_type) is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url) # Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url: if is_small_file and is_text_file and not force_download_url:
@@ -321,7 +317,7 @@ class ReadWorkspaceFileTool(BaseTool):
file_id=file_info.id, file_id=file_info.id,
name=file_info.name, name=file_info.name,
path=file_info.path, path=file_info.path,
mime_type=file_info.mime_type, mime_type=file_info.mimeType,
content_base64=content_b64, content_base64=content_b64,
message=f"Successfully read file: {file_info.name}", message=f"Successfully read file: {file_info.name}",
session_id=session_id, session_id=session_id,
@@ -350,11 +346,11 @@ class ReadWorkspaceFileTool(BaseTool):
file_id=file_info.id, file_id=file_info.id,
name=file_info.name, name=file_info.name,
path=file_info.path, path=file_info.path,
mime_type=file_info.mime_type, mime_type=file_info.mimeType,
size_bytes=file_info.size_bytes, size_bytes=file_info.sizeBytes,
download_url=download_url, download_url=download_url,
preview=preview, preview=preview,
message=f"File: {file_info.name} ({file_info.size_bytes} bytes). Use download_url to retrieve content.", message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id, session_id=session_id,
) )
@@ -382,9 +378,7 @@ class WriteWorkspaceFileTool(BaseTool):
@property @property
def description(self) -> str: def description(self) -> str:
return ( return (
"Write or create a file in the user's persistent workspace (cloud storage). " "Write or create a file in the user's workspace. "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Write tool instead. "
"Provide the content as a base64-encoded string. " "Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. " f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. " "Files are saved to the current session's folder by default. "
@@ -484,7 +478,7 @@ class WriteWorkspaceFileTool(BaseTool):
# Virus scan # Virus scan
await scan_content_safe(content, filename=filename) await scan_content_safe(content, filename=filename)
workspace = await workspace_db().get_or_create_workspace(user_id) workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access # Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id) manager = WorkspaceManager(user_id, workspace.id, session_id)
@@ -500,7 +494,7 @@ class WriteWorkspaceFileTool(BaseTool):
file_id=file_record.id, file_id=file_record.id,
name=file_record.name, name=file_record.name,
path=file_record.path, path=file_record.path,
size_bytes=file_record.size_bytes, size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}", message=f"Successfully wrote file: {file_record.name}",
session_id=session_id, session_id=session_id,
) )
@@ -529,7 +523,7 @@ class DeleteWorkspaceFileTool(BaseTool):
@property @property
def description(self) -> str: def description(self) -> str:
return ( return (
"Delete a file from the user's persistent workspace (cloud storage). " "Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. " "Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. " "Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access." "Use /sessions/<session_id>/... for cross-session access."
@@ -583,7 +577,7 @@ class DeleteWorkspaceFileTool(BaseTool):
) )
try: try:
workspace = await workspace_db().get_or_create_workspace(user_id) workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access # Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id) manager = WorkspaceManager(user_id, workspace.id, session_id)

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, Any, List, Literal from typing import TYPE_CHECKING, Annotated, List, Literal
from autogpt_libs.auth import get_user_id from autogpt_libs.auth import get_user_id
from fastapi import ( from fastapi import (
@@ -14,7 +14,7 @@ from fastapi import (
Security, Security,
status, status,
) )
from pydantic import BaseModel, Field, SecretStr, model_validator from pydantic import BaseModel, Field, SecretStr
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
from backend.api.features.library.db import set_preset_webhook, update_preset from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -39,11 +39,7 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import provider_matches from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager from backend.integrations.webhooks import get_webhook_manager
@@ -106,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
scopes: list[str] | None scopes: list[str] | None
username: str | None username: str | None
host: str | None = Field( host: str | None = Field(
default=None, default=None, description="Host pattern for host-scoped credentials"
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
) )
@model_validator(mode="before")
@classmethod
def _normalize_provider(cls, data: Any) -> Any:
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
if isinstance(data, dict):
prov = data.get("provider", "")
if isinstance(prov, str) and prov.startswith("ProviderName."):
member = prov.removeprefix("ProviderName.")
try:
data = {**data, "provider": ProviderName[member].value}
except KeyError:
pass
return data
@staticmethod
def get_host(cred: Credentials) -> str | None:
"""Extract host from credential: HostScoped host or MCP server URL."""
if isinstance(cred, HostScopedCredentials):
return cred.host
if isinstance(cred, OAuth2Credentials) and cred.provider in (
ProviderName.MCP,
ProviderName.MCP.value,
"ProviderName.MCP",
):
return (cred.metadata or {}).get("mcp_server_url")
return None
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens") @router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback( async def callback(
@@ -211,7 +179,9 @@ async def callback(
title=credentials.title, title=credentials.title,
scopes=credentials.scopes, scopes=credentials.scopes,
username=credentials.username, username=credentials.username,
host=(CredentialsMetaResponse.get_host(credentials)), host=(
credentials.host if isinstance(credentials, HostScopedCredentials) else None
),
) )
@@ -229,7 +199,7 @@ async def list_credentials(
title=cred.title, title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred), host=cred.host if isinstance(cred, HostScopedCredentials) else None,
) )
for cred in credentials for cred in credentials
] ]
@@ -252,7 +222,7 @@ async def list_credentials_by_provider(
title=cred.title, title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred), host=cred.host if isinstance(cred, HostScopedCredentials) else None,
) )
for cred in credentials for cred in credentials
] ]
@@ -352,11 +322,7 @@ async def delete_credentials(
tokens_revoked = None tokens_revoked = None
if isinstance(creds, OAuth2Credentials): if isinstance(creds, OAuth2Credentials):
if provider_matches(provider.value, ProviderName.MCP.value): handler = _get_provider_oauth_handler(request, provider)
# MCP uses dynamic per-server OAuth — create handler from metadata
handler = create_mcp_oauth_handler(creds)
else:
handler = _get_provider_oauth_handler(request, provider)
tokens_revoked = await handler.revoke_tokens(creds) tokens_revoked = await handler.revoke_tokens(creds)
return CredentialsDeletionResponse(revoked=tokens_revoked) return CredentialsDeletionResponse(revoked=tokens_revoked)

View File

@@ -12,11 +12,12 @@ import backend.api.features.store.image_gen as store_image_gen
import backend.api.features.store.media as store_media import backend.api.features.store.media as store_media
import backend.data.graph as graph_db import backend.data.graph as graph_db
import backend.data.integrations as integrations_db import backend.data.integrations as integrations_db
from backend.data.block import BlockInput
from backend.data.db import transaction from backend.data.db import transaction
from backend.data.execution import get_graph_execution from backend.data.execution import get_graph_execution
from backend.data.graph import GraphSettings from backend.data.graph import GraphSettings
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
from backend.data.model import CredentialsMetaInput, GraphInput from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import ( from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate, on_graph_activate,
@@ -1102,7 +1103,7 @@ async def create_preset_from_graph_execution(
raise NotFoundError( raise NotFoundError(
f"Graph #{graph_execution.graph_id} not found or accessible" f"Graph #{graph_execution.graph_id} not found or accessible"
) )
elif len(graph.aggregate_credentials_inputs()) > 0: elif len(graph.regular_credentials_inputs) > 0:
raise ValueError( raise ValueError(
f"Graph execution #{graph_exec_id} can't be turned into a preset " f"Graph execution #{graph_exec_id} can't be turned into a preset "
"because it was run before this feature existed " "because it was run before this feature existed "
@@ -1129,7 +1130,7 @@ async def create_preset_from_graph_execution(
async def update_preset( async def update_preset(
user_id: str, user_id: str,
preset_id: str, preset_id: str,
inputs: Optional[GraphInput] = None, inputs: Optional[BlockInput] = None,
credentials: Optional[dict[str, CredentialsMetaInput]] = None, credentials: Optional[dict[str, CredentialsMetaInput]] = None,
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,

View File

@@ -6,12 +6,9 @@ import prisma.enums
import prisma.models import prisma.models
import pydantic import pydantic
from backend.data.block import BlockInput
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
from backend.data.model import ( from backend.data.model import CredentialsMetaInput, is_credentials_field_name
CredentialsMetaInput,
GraphInput,
is_credentials_field_name,
)
from backend.util.json import loads as json_loads from backend.util.json import loads as json_loads
from backend.util.models import Pagination from backend.util.models import Pagination
@@ -326,7 +323,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
graph_id: str graph_id: str
graph_version: int graph_version: int
inputs: GraphInput inputs: BlockInput
credentials: dict[str, CredentialsMetaInput] credentials: dict[str, CredentialsMetaInput]
name: str name: str
@@ -355,7 +352,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
Request model used when updating a preset for a library agent. Request model used when updating a preset for a library agent.
""" """
inputs: Optional[GraphInput] = None inputs: Optional[BlockInput] = None
credentials: Optional[dict[str, CredentialsMetaInput]] = None credentials: Optional[dict[str, CredentialsMetaInput]] = None
name: Optional[str] = None name: Optional[str] = None
@@ -398,7 +395,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
"Webhook must be included in AgentPreset query when webhookId is set" "Webhook must be included in AgentPreset query when webhookId is set"
) )
input_data: GraphInput = {} input_data: BlockInput = {}
input_credentials: dict[str, CredentialsMetaInput] = {} input_credentials: dict[str, CredentialsMetaInput] = {}
for preset_input in preset.InputPresets: for preset_input in preset.InputPresets:

View File

@@ -1,404 +0,0 @@
"""
MCP (Model Context Protocol) API routes.
Provides endpoints for MCP tool discovery and OAuth authentication so the
frontend can list available tools on an MCP server before placing a block.
"""
import logging
from typing import Annotated, Any
from urllib.parse import urlparse
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
router = fastapi.APIRouter(tags=["mcp"])
creds_manager = IntegrationCredentialsManager()
# ====================== Tool Discovery ====================== #
class DiscoverToolsRequest(BaseModel):
"""Request to discover tools on an MCP server."""
server_url: str = Field(description="URL of the MCP server")
auth_token: str | None = Field(
default=None,
description="Optional Bearer token for authenticated MCP servers",
)
class MCPToolResponse(BaseModel):
"""A single MCP tool returned by discovery."""
name: str
description: str
input_schema: dict[str, Any]
class DiscoverToolsResponse(BaseModel):
"""Response containing the list of tools available on an MCP server."""
tools: list[MCPToolResponse]
server_name: str | None = None
protocol_version: str | None = None
@router.post(
"/discover-tools",
summary="Discover available tools on an MCP server",
response_model=DiscoverToolsResponse,
)
async def discover_tools(
request: DiscoverToolsRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> DiscoverToolsResponse:
"""
Connect to an MCP server and return its available tools.
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
):
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
# Refresh the token if expired before using it
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
client = MCPClient(request.server_url, auth_token=auth_token)
try:
init_result = await client.initialize()
tools = await client.list_tools()
except HTTPClientError as e:
if e.status_code in (401, 403):
raise fastapi.HTTPException(
status_code=401,
detail="This MCP server requires authentication. "
"Please provide a valid auth token.",
)
raise fastapi.HTTPException(status_code=502, detail=str(e))
except MCPClientError as e:
raise fastapi.HTTPException(status_code=502, detail=str(e))
except Exception as e:
raise fastapi.HTTPException(
status_code=502,
detail=f"Failed to connect to MCP server: {e}",
)
return DiscoverToolsResponse(
tools=[
MCPToolResponse(
name=t.name,
description=t.description,
input_schema=t.input_schema,
)
for t in tools
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or urlparse(request.server_url).hostname
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
)
# ======================== OAuth Flow ======================== #
class MCPOAuthLoginRequest(BaseModel):
"""Request to start an OAuth flow for an MCP server."""
server_url: str = Field(description="URL of the MCP server that requires OAuth")
class MCPOAuthLoginResponse(BaseModel):
"""Response with the OAuth login URL for the user to authenticate."""
login_url: str
state_token: str
@router.post(
"/oauth/login",
summary="Initiate OAuth login for an MCP server",
)
async def mcp_oauth_login(
request: MCPOAuthLoginRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> MCPOAuthLoginResponse:
"""
Discover OAuth metadata from the MCP server and return a login URL.
1. Discovers the protected-resource metadata (RFC 9728)
2. Fetches the authorization server metadata (RFC 8414)
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
client = MCPClient(request.server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
metadata: dict[str, Any] | None = None
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", request.server_url)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
else:
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
# and serve OAuth metadata directly without protected-resource metadata.
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(request.server_url)
if (
not metadata
or "authorization_endpoint" not in metadata
or "token_endpoint" not in metadata
):
raise fastapi.HTTPException(
status_code=400,
detail="This MCP server does not advertise OAuth support. "
"You may need to provide an auth token manually.",
)
authorize_url = metadata["authorization_endpoint"]
token_url = metadata["token_endpoint"]
registration_endpoint = metadata.get("registration_endpoint")
revoke_url = metadata.get("revocation_endpoint")
# Step 3: Dynamic Client Registration (RFC 7591) if available
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
client_id = ""
client_secret = ""
if registration_endpoint:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, request.server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
# Step 4: Store state token with OAuth metadata for the callback
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
"scopes_supported", []
)
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id,
ProviderName.MCP.value,
scopes,
state_metadata={
"authorize_url": authorize_url,
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": request.server_url,
"client_id": client_id,
"client_secret": client_secret,
},
)
# Step 5: Build and return the login URL
handler = MCPOAuthHandler(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
authorize_url=authorize_url,
token_url=token_url,
resource_url=resource_url,
)
login_url = handler.get_login_url(
scopes, state_token, code_challenge=code_challenge
)
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
class MCPOAuthCallbackRequest(BaseModel):
"""Request to exchange an OAuth code for tokens."""
code: str = Field(description="Authorization code from OAuth callback")
state_token: str = Field(description="State token for CSRF verification")
class MCPOAuthCallbackResponse(BaseModel):
"""Response after successfully storing OAuth credentials."""
credential_id: str
@router.post(
"/oauth/callback",
summary="Exchange OAuth code for MCP tokens",
)
async def mcp_oauth_callback(
request: MCPOAuthCallbackRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Exchange the authorization code for tokens and store the credential.
The frontend calls this after receiving the OAuth code from the popup.
On success, subsequent ``/discover-tools`` calls for the same server URL
will automatically use the stored credential.
"""
valid_state = await creds_manager.store.verify_state_token(
user_id, request.state_token, ProviderName.MCP.value
)
if not valid_state:
raise fastapi.HTTPException(
status_code=400,
detail="Invalid or expired state token.",
)
meta = valid_state.state_metadata
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
handler = MCPOAuthHandler(
client_id=meta["client_id"],
client_secret=meta.get("client_secret", ""),
redirect_uri=redirect_uri,
authorize_url=meta["authorize_url"],
token_url=meta["token_url"],
revoke_url=meta.get("revoke_url"),
resource_url=meta.get("resource_url"),
)
try:
credentials = await handler.exchange_code_for_tokens(
request.code, valid_state.scopes, valid_state.code_verifier
)
except Exception as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"OAuth token exchange failed: {e}",
)
# Enrich credential metadata for future lookup and token refresh
if credentials.metadata is None:
credentials.metadata = {}
credentials.metadata["mcp_server_url"] = meta["server_url"]
credentials.metadata["mcp_client_id"] = meta["client_id"]
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
for old in old_creds:
if (
isinstance(old, OAuth2Credentials)
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
f"Removed old MCP credential {old.id} for {meta['server_url']}"
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
await creds_manager.create(user_id, credentials)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=credentials.metadata.get("mcp_server_url"),
)
# ======================== Helpers ======================== #
async def _register_mcp_client(
registration_endpoint: str,
redirect_uri: str,
server_url: str,
) -> dict[str, Any] | None:
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
try:
response = await Requests(raise_for_status=True).post(
registration_endpoint,
json={
"client_name": "AutoGPT Platform",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
},
)
data = response.json()
if isinstance(data, dict) and "client_id" in data:
return data
return None
except Exception as e:
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
return None

View File

@@ -1,436 +0,0 @@
"""Tests for MCP API routes.
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
"""
from unittest.mock import AsyncMock, patch
import fastapi
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.auth import get_user_id
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
app.include_router(router)
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
@pytest_asyncio.fixture(scope="module")
async def client():
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
yield c
class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_success(self, client):
mock_tools = [
MCPTool(
name="get_weather",
description="Get weather for a city",
input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
),
MCPTool(
name="add_numbers",
description="Add two numbers",
input_schema={
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"},
},
},
),
]
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
"protocolVersion": "2025-03-26",
"serverInfo": {"name": "test-server"},
}
)
instance.list_tools = AsyncMock(return_value=mock_tools)
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert len(data["tools"]) == 2
assert data["tools"][0]["name"] == "get_weather"
assert data["tools"][1]["name"] == "add_numbers"
assert data["server_name"] == "test-server"
assert data["protocol_version"] == "2025-03-26"
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_with_auth_token(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={
"server_url": "https://mcp.example.com/mcp",
"auth_token": "my-secret-token",
},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="my-secret-token",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auto_uses_stored_credential(self, client):
"""When no explicit token is given, stored MCP credentials are used."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
access_token=SecretStr("stored-token-123"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="stored-token-123",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_mcp_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://bad-server.example.com/mcp"},
)
assert response.status_code == 502
assert "Connection refused" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_generic_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
response = await client.post(
"/discover-tools",
json={"server_url": "https://timeout.example.com/mcp"},
)
assert response.status_code == 502
assert "Failed to connect" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auth_required(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_forbidden(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_missing_url(self, client):
response = await client.post("/discover-tools", json={})
assert response.status_code == 422
class TestOAuthLogin:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_success(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch(
"backend.api.features.mcp.routes._register_mcp_client"
) as mock_register,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.sentry.io"],
"resource": "https://mcp.sentry.dev/mcp",
"scopes_supported": ["openid"],
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.sentry.io/authorize",
"token_endpoint": "https://auth.sentry.io/token",
"registration_endpoint": "https://auth.sentry.io/register",
}
)
mock_register.return_value = {
"client_id": "registered-client-id",
"client_secret": "registered-secret",
}
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-token-123", "code-challenge-abc")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.sentry.dev/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "login_url" in data
assert data["state_token"] == "state-token-123"
assert "auth.sentry.io/authorize" in data["login_url"]
assert "registered-client-id" in data["login_url"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_no_oauth_support(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.discover_auth = AsyncMock(return_value=None)
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
response = await client.post(
"/oauth/login",
json={"server_url": "https://simple-server.example.com/mcp"},
)
assert response.status_code == 400
assert "does not advertise OAuth" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_fallback_to_public_client(self, client):
"""When DCR is unavailable, falls back to default public client ID."""
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
# No registration_endpoint
}
)
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-abc", "challenge-xyz")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "autogpt-platform" in data["login_url"]
class TestOAuthCallback:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_success(self, client):
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
access_token=SecretStr("access-token-xyz"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={
"mcp_token_url": "https://auth.sentry.io/token",
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
},
)
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
# Mock state verification
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.sentry.io/authorize",
"token_url": "https://auth.sentry.io/token",
"client_id": "test-client-id",
"client_secret": "test-secret",
"server_url": "https://mcp.sentry.dev/mcp",
}
mock_state.scopes = ["openid"]
mock_state.code_verifier = "verifier-123"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
mock_cm.create = AsyncMock()
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
return_value=mock_creds
)
# Mock old credential cleanup
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
response = await client.post(
"/oauth/callback",
json={"code": "auth-code-abc", "state_token": "state-token-123"},
)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_invalid_state(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
response = await client.post(
"/oauth/callback",
json={"code": "auth-code", "state_token": "bad-state"},
)
assert response.status_code == 400
assert "Invalid or expired" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_token_exchange_fails(self, client):
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
"client_id": "cid",
"server_url": "https://mcp.example.com/mcp",
}
mock_state.scopes = []
mock_state.code_verifier = "v"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
side_effect=RuntimeError("Token exchange failed")
)
response = await client.post(
"/oauth/callback",
json={"code": "bad-code", "state_token": "state"},
)
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()

View File

@@ -5,8 +5,8 @@ from typing import Optional
import aiohttp import aiohttp
from fastapi import HTTPException from fastapi import HTTPException
from backend.blocks import get_block
from backend.data import graph as graph_db from backend.data import graph as graph_db
from backend.data.block import get_block
from backend.util.settings import Settings from backend.util.settings import Settings
from .models import ApiResponse, ChatRequest, GraphData from .models import ApiResponse, ChatRequest, GraphData

View File

@@ -152,7 +152,7 @@ class BlockHandler(ContentHandler):
async def get_missing_items(self, batch_size: int) -> list[ContentItem]: async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch blocks without embeddings.""" """Fetch blocks without embeddings."""
from backend.blocks import get_blocks from backend.data.block import get_blocks
# Get all available blocks # Get all available blocks
all_blocks = get_blocks() all_blocks = get_blocks()
@@ -249,7 +249,7 @@ class BlockHandler(ContentHandler):
async def get_stats(self) -> dict[str, int]: async def get_stats(self) -> dict[str, int]:
"""Get statistics about block embedding coverage.""" """Get statistics about block embedding coverage."""
from backend.blocks import get_blocks from backend.data.block import get_blocks
all_blocks = get_blocks() all_blocks = get_blocks()

View File

@@ -93,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
mock_existing = [] mock_existing = []
with patch( with patch(
"backend.blocks.get_blocks", "backend.data.block.get_blocks",
return_value=mock_blocks, return_value=mock_blocks,
): ):
with patch( with patch(
@@ -135,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
mock_embedded = [{"count": 2}] mock_embedded = [{"count": 2}]
with patch( with patch(
"backend.blocks.get_blocks", "backend.data.block.get_blocks",
return_value=mock_blocks, return_value=mock_blocks,
): ):
with patch( with patch(
@@ -327,7 +327,7 @@ async def test_block_handler_handles_missing_attributes():
mock_blocks = {"block-minimal": mock_block_class} mock_blocks = {"block-minimal": mock_block_class}
with patch( with patch(
"backend.blocks.get_blocks", "backend.data.block.get_blocks",
return_value=mock_blocks, return_value=mock_blocks,
): ):
with patch( with patch(
@@ -360,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
mock_blocks = {"good-block": good_block, "bad-block": bad_block} mock_blocks = {"good-block": good_block, "bad-block": bad_block}
with patch( with patch(
"backend.blocks.get_blocks", "backend.data.block.get_blocks",
return_value=mock_blocks, return_value=mock_blocks,
): ):
with patch( with patch(

View File

@@ -662,7 +662,7 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
) )
current_ids = {row["id"] for row in valid_agents} current_ids = {row["id"] for row in valid_agents}
elif content_type == ContentType.BLOCK: elif content_type == ContentType.BLOCK:
from backend.blocks import get_blocks from backend.data.block import get_blocks
current_ids = set(get_blocks().keys()) current_ids = set(get_blocks().keys())
elif content_type == ContentType.DOCUMENTATION: elif content_type == ContentType.DOCUMENTATION:

View File

@@ -8,7 +8,6 @@ Includes BM25 reranking for improved lexical relevance.
import logging import logging
import re import re
import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal from typing import Any, Literal
@@ -363,11 +362,7 @@ async def unified_hybrid_search(
LIMIT {limit_param} OFFSET {offset_param} LIMIT {limit_param} OFFSET {offset_param}
""" """
try: results = await query_raw_with_schema(sql_query, *params)
results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
total = results[0]["total_count"] if results else 0 total = results[0]["total_count"] if results else 0
# Apply BM25 reranking # Apply BM25 reranking
@@ -691,11 +686,7 @@ async def hybrid_search(
LIMIT {limit_param} OFFSET {offset_param} LIMIT {limit_param} OFFSET {offset_param}
""" """
try: results = await query_raw_with_schema(sql_query, *params)
results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
total = results[0]["total_count"] if results else 0 total = results[0]["total_count"] if results else 0
@@ -727,87 +718,6 @@ async def hybrid_search_simple(
return await hybrid_search(query=query, page=page, page_size=page_size) return await hybrid_search(query=query, page=page, page_size=page_size)
# ============================================================================
# Diagnostics
# ============================================================================
# Rate limit: only log vector error diagnostics once per this interval
_VECTOR_DIAG_INTERVAL_SECONDS = 60
_last_vector_diag_time: float = 0
async def _log_vector_error_diagnostics(error: Exception) -> None:
"""Log diagnostic info when 'type vector does not exist' error occurs.
Note: Diagnostic queries use query_raw_with_schema which may run on a different
pooled connection than the one that failed. Session-level search_path can differ,
so these diagnostics show cluster-wide state, not necessarily the failed session.
Includes rate limiting to avoid log spam - only logs once per minute.
Caller should re-raise the error after calling this function.
"""
global _last_vector_diag_time
# Check if this is the vector type error
error_str = str(error).lower()
if not (
"type" in error_str and "vector" in error_str and "does not exist" in error_str
):
return
# Rate limit: only log once per interval
now = time.time()
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
return
_last_vector_diag_time = now
try:
diagnostics: dict[str, object] = {}
try:
search_path_result = await query_raw_with_schema("SHOW search_path")
diagnostics["search_path"] = search_path_result
except Exception as e:
diagnostics["search_path"] = f"Error: {e}"
try:
schema_result = await query_raw_with_schema("SELECT current_schema()")
diagnostics["current_schema"] = schema_result
except Exception as e:
diagnostics["current_schema"] = f"Error: {e}"
try:
user_result = await query_raw_with_schema(
"SELECT current_user, session_user, current_database()"
)
diagnostics["user_info"] = user_result
except Exception as e:
diagnostics["user_info"] = f"Error: {e}"
try:
# Check pgvector extension installation (cluster-wide, stable info)
ext_result = await query_raw_with_schema(
"SELECT extname, extversion, nspname as schema "
"FROM pg_extension e "
"JOIN pg_namespace n ON e.extnamespace = n.oid "
"WHERE extname = 'vector'"
)
diagnostics["pgvector_extension"] = ext_result
except Exception as e:
diagnostics["pgvector_extension"] = f"Error: {e}"
logger.error(
f"Vector type error diagnostics:\n"
f" Error: {error}\n"
f" search_path: {diagnostics.get('search_path')}\n"
f" current_schema: {diagnostics.get('current_schema')}\n"
f" user_info: {diagnostics.get('user_info')}\n"
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
)
except Exception as diag_error:
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights # Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
# for existing code that expects the popularity parameter # for existing code that expects the popularity parameter
HybridSearchWeights = StoreAgentSearchWeights HybridSearchWeights = StoreAgentSearchWeights

View File

@@ -7,6 +7,15 @@ from replicate.client import Client as ReplicateClient
from replicate.exceptions import ReplicateError from replicate.exceptions import ReplicateError
from replicate.helpers import FileOutput from replicate.helpers import FileOutput
from backend.blocks.ideogram import (
AspectRatio,
ColorPalettePreset,
IdeogramModelBlock,
IdeogramModelName,
MagicPromptOption,
StyleType,
UpscaleOption,
)
from backend.data.graph import GraphBaseMeta from backend.data.graph import GraphBaseMeta
from backend.data.model import CredentialsMetaInput, ProviderName from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials from backend.integrations.credentials_store import ideogram_credentials
@@ -41,16 +50,6 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
if not ideogram_credentials.api_key: if not ideogram_credentials.api_key:
raise ValueError("Missing Ideogram API key") raise ValueError("Missing Ideogram API key")
from backend.blocks.ideogram import (
AspectRatio,
ColorPalettePreset,
IdeogramModelBlock,
IdeogramModelName,
MagicPromptOption,
StyleType,
UpscaleOption,
)
name = graph.name name = graph.name
description = f"{name} ({graph.description})" if graph.description else name description = f"{name} ({graph.description})" if graph.description else name

View File

@@ -40,11 +40,10 @@ from backend.api.model import (
UpdateTimezoneRequest, UpdateTimezoneRequest,
UploadFileResponse, UploadFileResponse,
) )
from backend.blocks import get_block, get_blocks
from backend.data import execution as execution_db from backend.data import execution as execution_db
from backend.data import graph as graph_db from backend.data import graph as graph_db
from backend.data.auth import api_key as api_key_db from backend.data.auth import api_key as api_key_db
from backend.data.block import BlockInput, CompletedBlockOutput from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.credit import ( from backend.data.credit import (
AutoTopUpConfig, AutoTopUpConfig,
RefundRequest, RefundRequest,

View File

@@ -11,7 +11,7 @@ import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response from fastapi.responses import Response
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage from backend.util.workspace_storage import get_workspace_storage
@@ -44,11 +44,11 @@ router = fastapi.APIRouter(
) )
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response: def _create_streaming_response(content: bytes, file) -> Response:
"""Create a streaming response for file content.""" """Create a streaming response for file content."""
return Response( return Response(
content=content, content=content,
media_type=file.mime_type, media_type=file.mimeType,
headers={ headers={
"Content-Disposition": _sanitize_filename_for_header(file.name), "Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)), "Content-Length": str(len(content)),
@@ -56,7 +56,7 @@ def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
) )
async def _create_file_download_response(file: WorkspaceFile) -> Response: async def _create_file_download_response(file) -> Response:
""" """
Create a download response for a workspace file. Create a download response for a workspace file.
@@ -66,33 +66,33 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
storage = await get_workspace_storage() storage = await get_workspace_storage()
# For local storage, stream the file directly # For local storage, stream the file directly
if file.storage_path.startswith("local://"): if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storage_path) content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file) return _create_streaming_response(content, file)
# For GCS, try to redirect to signed URL, fall back to streaming # For GCS, try to redirect to signed URL, fall back to streaming
try: try:
url = await storage.get_download_url(file.storage_path, expires_in=300) url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead # If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"): if url.startswith("/api/"):
content = await storage.retrieve(file.storage_path) content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file) return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302) return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e: except Exception as e:
# Log the signed URL failure with context # Log the signed URL failure with context
logger.error( logger.error(
f"Failed to get signed URL for file {file.id} " f"Failed to get signed URL for file {file.id} "
f"(storagePath={file.storage_path}): {e}", f"(storagePath={file.storagePath}): {e}",
exc_info=True, exc_info=True,
) )
# Fall back to streaming directly from GCS # Fall back to streaming directly from GCS
try: try:
content = await storage.retrieve(file.storage_path) content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file) return _create_streaming_response(content, file)
except Exception as fallback_error: except Exception as fallback_error:
logger.error( logger.error(
f"Fallback streaming also failed for file {file.id} " f"Fallback streaming also failed for file {file.id} "
f"(storagePath={file.storage_path}): {fallback_error}", f"(storagePath={file.storagePath}): {fallback_error}",
exc_info=True, exc_info=True,
) )
raise raise

View File

@@ -26,7 +26,6 @@ import backend.api.features.executions.review.routes
import backend.api.features.library.db import backend.api.features.library.db
import backend.api.features.library.model import backend.api.features.library.model
import backend.api.features.library.routes import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth import backend.api.features.oauth
import backend.api.features.otto.routes import backend.api.features.otto.routes
import backend.api.features.postmark.postmark import backend.api.features.postmark.postmark
@@ -41,11 +40,11 @@ import backend.data.user
import backend.integrations.webhooks.utils import backend.integrations.webhooks.utils
import backend.util.service import backend.util.service
import backend.util.settings import backend.util.settings
from backend.blocks.llm import DEFAULT_LLM_MODEL from backend.api.features.chat.completion_consumer import (
from backend.copilot.completion_consumer import (
start_completion_consumer, start_completion_consumer,
stop_completion_consumer, stop_completion_consumer,
) )
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials from backend.data.model import Credentials
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi from backend.monitoring.instrumentation import instrument_fastapi
@@ -344,11 +343,6 @@ app.include_router(
tags=["workspace"], tags=["workspace"],
prefix="/api/workspace", prefix="/api/workspace",
) )
app.include_router(
mcp_routes.router,
tags=["v2", "mcp"],
prefix="/api/mcp",
)
app.include_router( app.include_router(
backend.api.features.oauth.router, backend.api.features.oauth.router,
tags=["oauth"], tags=["oauth"],

View File

@@ -38,9 +38,7 @@ def main(**kwargs):
from backend.api.rest_api import AgentServer from backend.api.rest_api import AgentServer
from backend.api.ws_api import WebsocketServer from backend.api.ws_api import WebsocketServer
from backend.copilot.executor.manager import CoPilotExecutor from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.data.db_manager import DatabaseManager
from backend.executor import ExecutionManager, Scheduler
from backend.notifications import NotificationManager from backend.notifications import NotificationManager
run_processes( run_processes(
@@ -50,7 +48,6 @@ def main(**kwargs):
WebsocketServer(), WebsocketServer(),
AgentServer(), AgentServer(),
ExecutionManager(), ExecutionManager(),
CoPilotExecutor(),
**kwargs, **kwargs,
) )

View File

@@ -3,19 +3,22 @@ import logging
import os import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Sequence, Type, TypeVar from typing import TYPE_CHECKING, TypeVar
from backend.blocks._base import AnyBlockSchema, BlockType
from backend.util.cache import cached from backend.util.cache import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.block import Block
T = TypeVar("T") T = TypeVar("T")
@cached(ttl_seconds=3600) @cached(ttl_seconds=3600)
def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]: def load_all_blocks() -> dict[str, type["Block"]]:
from backend.blocks._base import Block from backend.data.block import Block
from backend.util.settings import Config from backend.util.settings import Config
# Check if example blocks should be loaded from settings # Check if example blocks should be loaded from settings
@@ -47,8 +50,8 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
importlib.import_module(f".{module}", package=__name__) importlib.import_module(f".{module}", package=__name__)
# Load all Block instances from the available modules # Load all Block instances from the available modules
available_blocks: dict[str, type["AnyBlockSchema"]] = {} available_blocks: dict[str, type["Block"]] = {}
for block_cls in _all_subclasses(Block): for block_cls in all_subclasses(Block):
class_name = block_cls.__name__ class_name = block_cls.__name__
if class_name.endswith("Base"): if class_name.endswith("Base"):
@@ -61,7 +64,7 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
"please name the class with 'Base' at the end" "please name the class with 'Base' at the end"
) )
block = block_cls() # pyright: ignore[reportAbstractUsage] block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36: if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError( raise ValueError(
@@ -102,7 +105,7 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
available_blocks[block.id] = block_cls available_blocks[block.id] = block_cls
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets # Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
from ._utils import is_block_auth_configured from backend.data.block import is_block_auth_configured
filtered_blocks = {} filtered_blocks = {}
for block_id, block_cls in available_blocks.items(): for block_id, block_cls in available_blocks.items():
@@ -112,48 +115,11 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
return filtered_blocks return filtered_blocks
def _all_subclasses(cls: type[T]) -> list[type[T]]: __all__ = ["load_all_blocks"]
def all_subclasses(cls: type[T]) -> list[type[T]]:
subclasses = cls.__subclasses__() subclasses = cls.__subclasses__()
for subclass in subclasses: for subclass in subclasses:
subclasses += _all_subclasses(subclass) subclasses += all_subclasses(subclass)
return subclasses return subclasses
# ============== Block access helper functions ============== #
def get_blocks() -> dict[str, Type["AnyBlockSchema"]]:
return load_all_blocks()
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> "AnyBlockSchema | None":
cls = get_blocks().get(block_id)
return cls() if cls else None
@cached(ttl_seconds=3600)
def get_webhook_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
]
@cached(ttl_seconds=3600)
def get_io_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
]
@cached(ttl_seconds=3600)
def get_human_in_the_loop_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type == BlockType.HUMAN_IN_THE_LOOP
]

View File

@@ -1,740 +0,0 @@
import inspect
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Optional,
Type,
TypeAlias,
TypeVar,
cast,
get_origin,
)
import jsonref
import jsonschema
from pydantic import BaseModel
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
SchemaField,
is_credentials_field_name,
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.exceptions import (
BlockError,
BlockExecutionError,
BlockInputError,
BlockOutputError,
BlockUnknownError,
)
from backend.util.settings import Config
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, NodeExecutionStats
from ..data.graph import Link
app_config = Config()
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
class BlockType(Enum):
STANDARD = "Standard"
INPUT = "Input"
OUTPUT = "Output"
NOTE = "Note"
WEBHOOK = "Webhook"
WEBHOOK_MANUAL = "Webhook (manual)"
AGENT = "Agent"
AI = "AI"
AYRSHARE = "Ayrshare"
HUMAN_IN_THE_LOOP = "Human In The Loop"
MCP_TOOL = "MCP Tool"
class BlockCategory(Enum):
AI = "Block that leverages AI to perform a task."
SOCIAL = "Block that interacts with social media platforms."
TEXT = "Block that processes text data."
SEARCH = "Block that searches or extracts information from the internet."
BASIC = "Block that performs basic operations."
INPUT = "Block that interacts with input of the graph."
OUTPUT = "Block that interacts with output of the graph."
LOGIC = "Programming logic to control the flow of your agent"
COMMUNICATION = "Block that interacts with communication platforms."
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
DATA = "Block that interacts with structured data."
HARDWARE = "Block that interacts with hardware."
AGENT = "Block that interacts with other agents."
CRM = "Block that interacts with CRM services."
SAFETY = (
"Block that provides AI safety mechanisms such as detecting harmful content"
)
PRODUCTIVITY = "Block that helps with productivity"
ISSUE_TRACKING = "Block that helps with issue tracking"
MULTIMEDIA = "Block that interacts with multimedia content"
MARKETING = "Block that helps with marketing"
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}
class BlockCostType(str, Enum):
RUN = "run" # cost X credits per run
BYTE = "byte" # cost X credits per byte
SECOND = "second" # cost X credits per second
class BlockCost(BaseModel):
cost_amount: int
cost_filter: BlockInput
cost_type: BlockCostType
def __init__(
self,
cost_amount: int,
cost_type: BlockCostType = BlockCostType.RUN,
cost_filter: Optional[BlockInput] = None,
**data: Any,
) -> None:
super().__init__(
cost_amount=cost_amount,
cost_filter=cost_filter or {},
cost_type=cost_type,
**data,
)
class BlockInfo(BaseModel):
id: str
name: str
inputSchema: dict[str, Any]
outputSchema: dict[str, Any]
costs: list[BlockCost]
description: str
categories: list[dict[str, str]]
contributors: list[dict[str, Any]]
staticOutput: bool
uiType: str
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:
if cls.cached_jsonschema:
return cls.cached_jsonschema
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields that has a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
if one_key:
obj.update(obj[one_key][0])
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
return obj
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
return cls.cached_jsonschema
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(
schema=cls.jsonschema(),
data={k: v for k, v in data.items() if v is not None},
)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return cls.validate_data(data)
@classmethod
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
model_schema = cls.jsonschema().get("properties", {})
if not model_schema:
raise ValueError(f"Invalid model schema {cls}")
property_schema = model_schema.get(field_name)
if not property_schema:
raise ValueError(f"Invalid property name {field_name}")
return property_schema
@classmethod
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
"""
Validate the data against a specific property (one of the input/output name).
Returns the validation error message if the data does not match the schema.
"""
try:
property_schema = cls.get_field_schema(field_name)
jsonschema.validate(json.to_dict(data), property_schema)
return None
except jsonschema.ValidationError as e:
return str(e)
@classmethod
def get_fields(cls) -> set[str]:
return set(cls.model_fields.keys())
@classmethod
def get_required_fields(cls) -> set[str]:
return {
field
for field, field_info in cls.model_fields.items()
if field_info.is_required()
}
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""Validates the schema definition. Rules:
- Fields with annotation `CredentialsMetaInput` MUST be
named `credentials` or `*_credentials`
- Fields named `credentials` or `*_credentials` MUST be
of type `CredentialsMetaInput`
"""
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}
credentials_fields = cls.get_credentials_fields()
for field_name in cls.get_fields():
if is_credentials_field_name(field_name):
if field_name not in credentials_fields:
raise TypeError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
f"is not of type {CredentialsMetaInput.__name__}"
)
CredentialsMetaInput.validate_credentials_field_schema(
cls.get_field_schema(field_name), field_name
)
elif field_name in credentials_fields:
raise KeyError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
"has invalid name: must be 'credentials' or *_credentials"
)
@classmethod
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
return {
field_name: info.annotation
for field_name, info in cls.model_fields.items()
if (
inspect.isclass(info.annotation)
and issubclass(
get_origin(info.annotation) or info.annotation,
CredentialsMetaInput,
)
)
}
@classmethod
def get_auto_credentials_fields(cls) -> dict[str, dict[str, Any]]:
"""
Get fields that have auto_credentials metadata (e.g., GoogleDriveFileInput).
Returns a dict mapping kwarg_name -> {field_name, auto_credentials_config}
Raises:
ValueError: If multiple fields have the same kwarg_name, as this would
cause silent overwriting and only the last field would be processed.
"""
result: dict[str, dict[str, Any]] = {}
schema = cls.jsonschema()
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
auto_creds = field_schema.get("auto_credentials")
if auto_creds:
kwarg_name = auto_creds.get("kwarg_name", "credentials")
if kwarg_name in result:
raise ValueError(
f"Duplicate auto_credentials kwarg_name '{kwarg_name}' "
f"in fields '{result[kwarg_name]['field_name']}' and "
f"'{field_name}' on {cls.__qualname__}"
)
result[kwarg_name] = {
"field_name": field_name,
"config": auto_creds,
}
return result
@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
result = {}
# Regular credentials fields
for field_name in cls.get_credentials_fields().keys():
result[field_name] = CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
# Auto-generated credentials fields (from GoogleDriveFileInput etc.)
for kwarg_name, info in cls.get_auto_credentials_fields().items():
config = info["config"]
# Build a schema-like dict that CredentialsFieldInfo can parse
auto_schema = {
"credentials_provider": [config.get("provider", "google")],
"credentials_types": [config.get("type", "oauth2")],
"credentials_scopes": config.get("scopes"),
}
result[kwarg_name] = CredentialsFieldInfo.model_validate(
auto_schema, by_alias=True
)
return result
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
input_fields_from_nodes = {link.sink_name for link in links}
return input_fields_from_nodes - set(data)
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
return cls.get_required_fields() - set(data)
class BlockSchemaInput(BlockSchema):
"""
Base schema class for block inputs.
All block input schemas should extend this class for consistency.
"""
pass
class BlockSchemaOutput(BlockSchema):
"""
Base schema class for block outputs that includes a standard error field.
All block output schemas should extend this class to ensure consistent error handling.
"""
error: str = SchemaField(
description="Error message if the operation failed", default=""
)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchemaInput)
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchemaOutput)
class EmptyInputSchema(BlockSchemaInput):
pass
class EmptyOutputSchema(BlockSchemaOutput):
pass
# For backward compatibility - will be deprecated
EmptySchema = EmptyOutputSchema
# --8<-- [start:BlockWebhookConfig]
class BlockManualWebhookConfig(BaseModel):
"""
Configuration model for webhook-triggered blocks on which
the user has to manually set up the webhook at the provider.
"""
provider: ProviderName
"""The service provider that the webhook connects to"""
webhook_type: str
"""
Identifier for the webhook type. E.g. GitHub has repo and organization level hooks.
Only for use in the corresponding `WebhooksManager`.
"""
event_filter_input: str = ""
"""
Name of the block's event filter input.
Leave empty if the corresponding webhook doesn't have distinct event/payload types.
"""
event_format: str = "{event}"
"""
Template string for the event(s) that a block instance subscribes to.
Applied individually to each event selected in the event filter input.
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
"""
class BlockWebhookConfig(BlockManualWebhookConfig):
"""
Configuration model for webhook-triggered blocks for which
the webhook can be automatically set up through the provider's API.
"""
resource_format: str
"""
Template string for the resource that a block instance subscribes to.
Fields will be filled from the block's inputs (except `payload`).
Example: `f"{repo}/pull_requests"` (note: not how it's actually implemented)
Only for use in the corresponding `WebhooksManager`.
"""
# --8<-- [end:BlockWebhookConfig]
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
def __init__(
self,
id: str = "",
description: str = "",
contributors: list["ContributorDetails"] = [],
categories: set[BlockCategory] | None = None,
input_schema: Type[BlockSchemaInputType] = EmptyInputSchema,
output_schema: Type[BlockSchemaOutputType] = EmptyOutputSchema,
test_input: BlockInput | list[BlockInput] | None = None,
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
test_mock: dict[str, Any] | None = None,
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
disabled: bool = False,
static_output: bool = False,
block_type: BlockType = BlockType.STANDARD,
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
is_sensitive_action: bool = False,
):
"""
Initialize the block with the given schema.
Args:
id: The unique identifier for the block, this value will be persisted in the
DB. So it should be a unique and constant across the application run.
Use the UUID format for the ID.
description: The description of the block, explaining what the block does.
contributors: The list of contributors who contributed to the block.
input_schema: The schema, defined as a Pydantic model, for the input data.
output_schema: The schema, defined as a Pydantic model, for the output data.
test_input: The list or single sample input data for the block, for testing.
test_output: The list or single expected output if the test_input is run.
test_mock: function names on the block implementation to mock on test run.
disabled: If the block is disabled, it will not be available for execution.
static_output: Whether the output links of the block are static by default.
"""
from backend.data.model import NodeExecutionStats
self.id = id
self.input_schema = input_schema
self.output_schema = output_schema
self.test_input = test_input
self.test_output = test_output
self.test_mock = test_mock
self.test_credentials = test_credentials
self.description = description
self.categories = categories or set()
self.contributors = contributors or set()
self.disabled = disabled
self.static_output = static_output
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
# Enforce presence of credentials field on auto-setup webhook blocks
if not (cred_fields := self.input_schema.get_credentials_fields()):
raise TypeError(
"credentials field is required on auto-setup webhook blocks"
)
# Disallow multiple credentials inputs on webhook blocks
elif len(cred_fields) > 1:
raise ValueError(
"Multiple credentials inputs not supported on webhook blocks"
)
self.block_type = BlockType.WEBHOOK
else:
self.block_type = BlockType.WEBHOOK_MANUAL
# Enforce shape of webhook event filter, if present
if self.webhook_config.event_filter_input:
event_filter_field = self.input_schema.model_fields[
self.webhook_config.event_filter_input
]
if not (
isinstance(event_filter_field.annotation, type)
and issubclass(event_filter_field.annotation, BaseModel)
and all(
field.annotation is bool
for field in event_filter_field.annotation.model_fields.values()
)
):
raise NotImplementedError(
f"{self.name} has an invalid webhook event selector: "
"field must be a BaseModel and all its fields must be boolean"
)
# Enforce presence of 'payload' input
if "payload" not in self.input_schema.model_fields:
raise TypeError(
f"{self.name} is webhook-triggered but has no 'payload' input"
)
# Disable webhook-triggered block if webhook functionality not available
if not app_config.platform_base_url:
self.disabled = True
@abstractmethod
async def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
"""
Run the block with the given input data.
Args:
input_data: The input data with the structure of input_schema.
Kwargs: Currently 14/02/2025 these include
graph_id: The ID of the graph.
node_id: The ID of the node.
graph_exec_id: The ID of the graph execution.
node_exec_id: The ID of the node execution.
user_id: The ID of the user.
Returns:
A Generator that yields (output_name, output_data).
output_name: One of the output name defined in Block's output_schema.
output_data: The data for the output_name, matching the defined schema.
"""
# --- satisfy the type checker, never executed -------------
if False: # noqa: SIM115
yield "name", "value" # pyright: ignore[reportMissingYield]
raise NotImplementedError(f"{self.name} does not implement the run method.")
async def run_once(
self, input_data: BlockSchemaInputType, output: str, **kwargs
) -> Any:
async for item in self.run(input_data, **kwargs):
name, data = item
if name == output:
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
self.execution_stats += stats
return self.execution_stats
@property
def name(self):
return self.__class__.__name__
def to_dict(self):
return {
"id": self.id,
"name": self.name,
"inputSchema": self.input_schema.jsonschema(),
"outputSchema": self.output_schema.jsonschema(),
"description": self.description,
"categories": [category.dict() for category in self.categories],
"contributors": [
contributor.model_dump() for contributor in self.contributors
],
"staticOutput": self.static_output,
"uiType": self.block_type.value,
}
def get_info(self) -> BlockInfo:
from backend.data.credit import get_block_cost
return BlockInfo(
id=self.id,
name=self.name,
inputSchema=self.input_schema.jsonschema(),
outputSchema=self.output_schema.jsonschema(),
costs=get_block_cost(self),
description=self.description,
categories=[category.dict() for category in self.categories],
contributors=[
contributor.model_dump() for contributor in self.contributors
],
staticOutput=self.static_output,
uiType=self.block_type.value,
)
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
try:
async for output_name, output_data in self._execute(input_data, **kwargs):
yield output_name, output_data
except Exception as ex:
if isinstance(ex, BlockError):
raise ex
else:
raise (
BlockExecutionError
if isinstance(ex, ValueError)
else BlockUnknownError
)(
message=str(ex),
block_name=self.name,
block_id=self.id,
) from ex
async def is_block_exec_need_review(
self,
input_data: BlockInput,
*,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
**kwargs,
) -> tuple[bool, BlockInput]:
"""
Check if this block execution needs human review and handle the review process.
Returns:
Tuple of (should_pause, input_data_to_use)
- should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer)
"""
if not (
self.is_sensitive_action and execution_context.sensitive_action_safe_mode
):
return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper
# Handle the review request and get decision
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
block_name=self.name,
editable=True,
)
if decision is None:
# We're awaiting review - pause execution
return True, input_data
if not decision.should_proceed:
# Review was rejected, raise an error to stop execution
raise BlockExecutionError(
message=f"Block execution rejected by reviewer: {decision.message}",
block_name=self.name,
block_id=self.id,
)
# Review was approved - use the potentially modified data
# ReviewResult.data must be a dict for block inputs
reviewed_data = decision.review_result.data
if not isinstance(reviewed_data, dict):
raise BlockExecutionError(
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
block_name=self.name,
block_id=self.id,
)
return False, reviewed_data
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Check for review requirement only if running within a graph execution context
# Direct block execution (e.g., from chat) skips the review process
has_graph_context = all(
key in kwargs
for key in (
"node_exec_id",
"graph_exec_id",
"graph_id",
"execution_context",
)
)
if has_graph_context:
should_pause, input_data = await self.is_block_exec_need_review(
input_data, **kwargs
)
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,
):
if output_name == "error":
raise BlockExecutionError(
message=output_data, block_name=self.name, block_id=self.id
)
if self.block_type == BlockType.STANDARD and (
error := self.output_schema.validate_field(output_name, output_data)
):
raise BlockOutputError(
message=f"Block produced an invalid output data: {error}",
block_name=self.name,
block_id=self.id,
)
yield output_name, output_data
def is_triggered_by_event_type(
self, trigger_config: dict[str, Any], event_type: str
) -> bool:
if not self.webhook_config:
raise TypeError("This method can't be used on non-trigger blocks")
if not self.webhook_config.event_filter_input:
return True
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
if not event_filter:
raise ValueError("Event filter is not configured on trigger")
return event_type in [
self.webhook_config.event_format.format(event=k)
for k in event_filter
if event_filter[k] is True
]
# Type alias for any block with standard input/output schemas
AnyBlockSchema: TypeAlias = Block[BlockSchemaInput, BlockSchemaOutput]

View File

@@ -1,122 +0,0 @@
import logging
import os
from backend.integrations.providers import ProviderName
from ._base import AnyBlockSchema
logger = logging.getLogger(__name__)
def is_block_auth_configured(
block_cls: type[AnyBlockSchema],
) -> bool:
"""
Check if a block has a valid authentication method configured at runtime.
For example if a block is an OAuth-only block and there env vars are not set,
do not show it in the UI.
"""
from backend.sdk.registry import AutoRegistry
# Create an instance to access input_schema
try:
block = block_cls()
except Exception as e:
# If we can't create a block instance, assume it's not OAuth-only
logger.error(f"Error creating block instance for {block_cls.__name__}: {e}")
return True
logger.debug(
f"Checking if block {block_cls.__name__} has a valid provider configured"
)
# Get all credential inputs from input schema
credential_inputs = block.input_schema.get_credentials_fields_info()
required_inputs = block.input_schema.get_required_fields()
if not credential_inputs:
logger.debug(
f"Block {block_cls.__name__} has no credential inputs - Treating as valid"
)
return True
# Check credential inputs
if len(required_inputs.intersection(credential_inputs.keys())) == 0:
logger.debug(
f"Block {block_cls.__name__} has only optional credential inputs"
" - will work without credentials configured"
)
# Check if the credential inputs for this block are correctly configured
for field_name, field_info in credential_inputs.items():
provider_names = field_info.provider
if not provider_names:
logger.warning(
f"Block {block_cls.__name__} "
f"has credential input '{field_name}' with no provider options"
" - Disabling"
)
return False
# If a field has multiple possible providers, each one needs to be usable to
# prevent breaking the UX
for _provider_name in provider_names:
provider_name = _provider_name.value
if provider_name in ProviderName.__members__.values():
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' is part of the legacy provider system"
" - Treating as valid"
)
break
provider = AutoRegistry.get_provider(provider_name)
if not provider:
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"refers to unknown provider '{provider_name}' - Disabling"
)
return False
# Check the provider's supported auth types
if field_info.supported_types != provider.supported_auth_types:
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"has mismatched supported auth types (field <> Provider): "
f"{field_info.supported_types} != {provider.supported_auth_types}"
)
if not (supported_auth_types := provider.supported_auth_types):
# No auth methods are been configured for this provider
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' "
"has no authentication methods configured - Disabling"
)
return False
# Check if provider supports OAuth
if "oauth2" in supported_auth_types:
# Check if OAuth environment variables are set
if (oauth_config := provider.oauth_config) and bool(
os.getenv(oauth_config.client_id_env_var)
and os.getenv(oauth_config.client_secret_env_var)
):
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' is configured for OAuth"
)
else:
logger.error(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' "
"is missing OAuth client ID or secret - Disabling"
)
return False
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' is valid; "
f"supported credential types: {', '.join(field_info.supported_types)}"
)
return True

View File

@@ -1,7 +1,7 @@
import logging import logging
from typing import TYPE_CHECKING, Any, Optional from typing import Any, Optional
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockInput, BlockInput,
@@ -9,15 +9,13 @@ from backend.blocks._base import (
BlockSchema, BlockSchema,
BlockSchemaInput, BlockSchemaInput,
BlockType, BlockType,
get_block,
) )
from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
from backend.data.model import NodeExecutionStats, SchemaField from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry from backend.util.retry import func_retry
if TYPE_CHECKING:
from backend.executor.utils import LogMetadata
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@@ -126,10 +124,9 @@ class AgentExecutorBlock(Block):
graph_version: int, graph_version: int,
graph_exec_id: str, graph_exec_id: str,
user_id: str, user_id: str,
logger: "LogMetadata", logger,
) -> BlockOutput: ) -> BlockOutput:
from backend.blocks import get_block
from backend.data.execution import ExecutionEventType from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils from backend.executor import utils as execution_utils
@@ -201,7 +198,7 @@ class AgentExecutorBlock(Block):
self, self,
graph_exec_id: str, graph_exec_id: str,
user_id: str, user_id: str,
logger: "LogMetadata", logger,
) -> None: ) -> None:
from backend.executor import utils as execution_utils from backend.executor import utils as execution_utils

View File

@@ -1,11 +1,5 @@
from typing import Any from typing import Any
from backend.blocks._base import (
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.llm import ( from backend.blocks.llm import (
DEFAULT_LLM_MODEL, DEFAULT_LLM_MODEL,
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -17,6 +11,12 @@ from backend.blocks.llm import (
LLMResponse, LLMResponse,
llm_call, llm_call,
) )
from backend.data.block import (
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField

View File

@@ -6,7 +6,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput from replicate.helpers import FileOutput
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -5,12 +5,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput from replicate.helpers import FileOutput
from backend.blocks._base import ( from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
Block,
BlockCategory,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import ( from backend.data.model import (
APIKeyCredentials, APIKeyCredentials,

View File

@@ -6,7 +6,7 @@ from typing import Literal
from pydantic import SecretStr from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -6,7 +6,7 @@ from typing import Literal
from pydantic import SecretStr from pydantic import SecretStr
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

View File

@@ -1,10 +1,3 @@
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import ( from backend.blocks.apollo._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -17,6 +10,13 @@ from backend.blocks.apollo.models import (
PrimaryPhone, PrimaryPhone,
SearchOrganizationsRequest, SearchOrganizationsRequest,
) )
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField from backend.data.model import CredentialsField, SchemaField

View File

@@ -1,12 +1,5 @@
import asyncio import asyncio
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import ( from backend.blocks.apollo._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -21,6 +14,13 @@ from backend.blocks.apollo.models import (
SearchPeopleRequest, SearchPeopleRequest,
SenorityLevels, SenorityLevels,
) )
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField from backend.data.model import CredentialsField, SchemaField

View File

@@ -1,10 +1,3 @@
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import ( from backend.blocks.apollo._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -13,6 +6,13 @@ from backend.blocks.apollo._auth import (
ApolloCredentialsInput, ApolloCredentialsInput,
) )
from backend.blocks.apollo.models import Contact, EnrichPersonRequest from backend.blocks.apollo.models import Contact, EnrichPersonRequest
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField from backend.data.model import CredentialsField, SchemaField

View File

@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from backend.blocks._base import BlockSchemaInput from backend.data.block import BlockSchemaInput
from backend.data.model import SchemaField, UserIntegrations from backend.data.model import SchemaField, UserIntegrations
from backend.integrations.ayrshare import AyrshareClient from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client from backend.util.clients import get_database_manager_async_client

View File

@@ -1,7 +1,7 @@
import enum import enum
from typing import Any from typing import Any
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
@@ -126,7 +126,6 @@ class PrintToConsoleBlock(Block):
output_schema=PrintToConsoleBlock.Output, output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"}, test_input={"text": "Hello, World!"},
is_sensitive_action=True, is_sensitive_action=True,
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
test_output=[ test_output=[
("output", "Hello, World!"), ("output", "Hello, World!"),
("status", "printed"), ("status", "printed"),

View File

@@ -2,7 +2,7 @@ import os
import re import re
from typing import Type from typing import Type
from backend.blocks._base import ( from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,

Some files were not shown because too many files have changed in this diff Show More