Compare commits

..

1 Commits

Author SHA1 Message Date
Bentlybro
6b1f0df58c fix(backend): Clean up orphaned schedules without schedule_id
Old scheduled jobs created before schedule_id was added to
GraphExecutionJobArgs have schedule_id=None. When these fail
validation, _handle_graph_validation_error could not unschedule
them, causing them to fire repeatedly and generate ~60K+ Sentry
errors (AUTOGPT-SERVER-6W2 and AUTOGPT-SERVER-6W3).

Fix: Add _cleanup_old_schedules_without_id() which finds schedules
for the graph but only removes those with schedule_id=None (legacy
jobs). This preserves any valid newer schedules the user may have
created, unlike the broader _cleanup_orphaned_schedules_for_graph()
which removes all schedules for a graph.
2026-02-02 14:15:25 +00:00
634 changed files with 18915 additions and 49955 deletions

View File

@@ -5,13 +5,42 @@
!docs/
# Platform - Libs
!autogpt_platform/autogpt_libs/
!autogpt_platform/autogpt_libs/autogpt_libs/
!autogpt_platform/autogpt_libs/pyproject.toml
!autogpt_platform/autogpt_libs/poetry.lock
!autogpt_platform/autogpt_libs/README.md
# Platform - Backend
!autogpt_platform/backend/
!autogpt_platform/backend/backend/
!autogpt_platform/backend/test/e2e_test_data.py
!autogpt_platform/backend/migrations/
!autogpt_platform/backend/schema.prisma
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py
# Platform - Market
!autogpt_platform/market/market/
!autogpt_platform/market/scripts.py
!autogpt_platform/market/schema.prisma
!autogpt_platform/market/pyproject.toml
!autogpt_platform/market/poetry.lock
!autogpt_platform/market/README.md
# Platform - Frontend
!autogpt_platform/frontend/
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json
!autogpt_platform/frontend/README.md
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env
# Classic - AutoGPT
!classic/original_autogpt/autogpt/
@@ -35,38 +64,6 @@
# Classic - Frontend
!classic/frontend/build/web/
# Explicitly re-ignore unwanted files from whitelisted directories
# Note: These patterns MUST come after the whitelist rules to take effect
# Hidden files and directories (but keep frontend .env files needed for build)
**/.*
!autogpt_platform/frontend/.env
!autogpt_platform/frontend/.env.default
!autogpt_platform/frontend/.env.production
# Python artifacts
**/__pycache__/
**/*.pyc
**/*.pyo
**/.venv/
**/.ruff_cache/
**/.pytest_cache/
**/.coverage
**/htmlcov/
# Node artifacts
**/node_modules/
**/.next/
**/storybook-static/
**/playwright-report/
**/test-results/
# Build artifacts
**/dist/
**/build/
!autogpt_platform/frontend/src/**/build/
**/target/
# Logs and temp files
**/*.log
**/*.tmp
# Explicitly re-ignore some folders
.*
**/__pycache__

File diff suppressed because it is too large Load Diff

View File

@@ -49,7 +49,7 @@ jobs:
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v8
uses: peter-evans/create-pull-request@v7
with:
add-paths: classic/frontend/build/web
base: ${{ github.ref_name }}

View File

@@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.event.workflow_run.head_branch }}
fetch-depth: 0
@@ -40,51 +40,9 @@ jobs:
git checkout -b "$BRANCH_NAME"
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
# Backend Python/Poetry setup (so Claude can run linting/tests)
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
run: |
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Install Python dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (so Claude can run linting/tests)
- name: Enable corepack
run: corepack enable
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
run: pnpm install --frozen-lockfile
- name: Get CI failure details
id: failure_details
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const run = await github.rest.actions.getWorkflowRun({

View File

@@ -30,7 +30,7 @@ jobs:
actions: read # Required for CI access
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -41,7 +41,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -77,15 +77,27 @@ jobs:
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Enable corepack
run: corepack enable
- name: Set up Node.js
uses: actions/setup-node@v6
- name: Set pnpm store directory
run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v4
with:
node-version: "22"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
@@ -112,7 +124,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes
@@ -297,7 +309,6 @@ jobs:
uses: anthropics/claude-code-action@v1
with:
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
allowed_bots: "dependabot[bot]"
claude_args: |
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
prompt: |

View File

@@ -40,7 +40,7 @@ jobs:
actions: read # Required for CI access
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -57,7 +57,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -93,15 +93,27 @@ jobs:
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Enable corepack
run: corepack enable
- name: Set up Node.js
uses: actions/setup-node@v6
- name: Set pnpm store directory
run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v4
with:
node-version: "22"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
@@ -128,7 +140,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes

View File

@@ -58,11 +58,11 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@@ -93,6 +93,6 @@ jobs:
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

View File

@@ -27,7 +27,7 @@ jobs:
# If you do not check out your code, Copilot will do this for you.
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
@@ -39,7 +39,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -76,7 +76,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"
@@ -89,7 +89,7 @@ jobs:
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -132,7 +132,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes

View File

@@ -23,7 +23,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -33,7 +33,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -7,10 +7,6 @@ on:
- "docs/integrations/**"
- "autogpt_platform/backend/backend/blocks/**"
concurrency:
group: claude-docs-review-${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs:
claude-review:
# Only run for PRs from members/collaborators
@@ -27,7 +23,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
@@ -37,7 +33,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -95,35 +91,5 @@ jobs:
3. Read corresponding documentation files to verify accuracy
4. Provide your feedback as a PR comment
## IMPORTANT: Comment Marker
Start your PR comment with exactly this HTML comment marker on its own line:
<!-- CLAUDE_DOCS_REVIEW -->
This marker is used to identify and replace your comment on subsequent runs.
Be constructive and specific. If everything looks good, say so!
If there are issues, explain what's wrong and suggest how to fix it.
- name: Delete old Claude review comments
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Get all comment IDs with our marker, sorted by creation date (oldest first)
COMMENT_IDS=$(gh api \
repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
--jq '[.[] | select(.body | contains("<!-- CLAUDE_DOCS_REVIEW -->"))] | sort_by(.created_at) | .[].id')
# Count comments
COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
if [ "$COMMENT_COUNT" -gt 1 ]; then
# Delete all but the last (newest) comment
echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
if [ -n "$COMMENT_ID" ]; then
echo "Deleting old review comment: $COMMENT_ID"
gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
fi
done
else
echo "No old review comments to clean up"
fi

View File

@@ -28,7 +28,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -38,7 +38,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -25,7 +25,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
@@ -52,7 +52,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -17,7 +17,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name || 'master' }}
@@ -45,7 +45,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -41,18 +41,13 @@ jobs:
ports:
- 6379:6379
rabbitmq:
image: rabbitmq:4.1.4
image: rabbitmq:3.12-management
ports:
- 5672:5672
- 15672:15672
env:
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
options: >-
--health-cmd "rabbitmq-diagnostics -q ping"
--health-interval 30s
--health-timeout 10s
--health-retries 5
--health-start-period 10s
clamav:
image: clamav/clamav-debian:latest
ports:
@@ -73,7 +68,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
@@ -93,7 +88,7 @@ jobs:
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -17,7 +17,7 @@ jobs:
- name: Check comment permissions and deployment status
id: check_status
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const commentBody = context.payload.comment.body.trim();
@@ -55,7 +55,7 @@ jobs:
- name: Post permission denied comment
if: steps.check_status.outputs.permission_denied == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -68,7 +68,7 @@ jobs:
- name: Get PR details for deployment
id: pr_details
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const pr = await github.rest.pulls.get({
@@ -82,7 +82,7 @@ jobs:
- name: Dispatch Deploy Event
if: steps.check_status.outputs.should_deploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -98,7 +98,7 @@ jobs:
- name: Post deploy success comment
if: steps.check_status.outputs.should_deploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -110,7 +110,7 @@ jobs:
- name: Dispatch Undeploy Event (from comment)
if: steps.check_status.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -126,7 +126,7 @@ jobs:
- name: Post undeploy success comment
if: steps.check_status.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -139,7 +139,7 @@ jobs:
- name: Check deployment status on PR close
id: check_pr_close
if: github.event_name == 'pull_request' && github.event.action == 'closed'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const comments = await github.rest.issues.listComments({
@@ -168,7 +168,7 @@ jobs:
github.event_name == 'pull_request' &&
github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -187,7 +187,7 @@ jobs:
github.event_name == 'pull_request' &&
github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({

View File

@@ -6,16 +6,10 @@ on:
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
- "autogpt_platform/backend/Dockerfile"
- "autogpt_platform/docker-compose.yml"
- "autogpt_platform/docker-compose.platform.yml"
pull_request:
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
- "autogpt_platform/backend/Dockerfile"
- "autogpt_platform/docker-compose.yml"
- "autogpt_platform/docker-compose.platform.yml"
merge_group:
workflow_dispatch:
@@ -32,31 +26,34 @@ jobs:
setup:
runs-on: ubuntu-latest
outputs:
components-changed: ${{ steps.filter.outputs.components }}
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Check for component changes
uses: dorny/paths-filter@v3
id: filter
- name: Set up Node.js
uses: actions/setup-node@v4
with:
filters: |
components:
- 'autogpt_platform/frontend/src/components/**'
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Install dependencies to populate cache
- name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
lint:
@@ -65,17 +62,24 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Set up Node
uses: actions/setup-node@v6
- name: Restore dependencies cache
uses: actions/cache@v4
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
@@ -86,27 +90,31 @@ jobs:
chromatic:
runs-on: ubuntu-latest
needs: setup
# Disabled: to re-enable, remove 'false &&' from the condition below
if: >-
false
&& (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
&& needs.setup.outputs.components-changed == 'true'
# Only run on dev branch pushes or PRs targeting dev
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Set up Node
uses: actions/setup-node@v6
- name: Restore dependencies cache
uses: actions/cache@v4
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
@@ -121,20 +129,30 @@ jobs:
exitOnceUploaded: true
e2e_test:
name: end-to-end tests
runs-on: big-boi
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Platform - Copy default supabase .env
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key
- name: Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
@@ -142,125 +160,77 @@ jobs:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Cache Docker layers
uses: actions/cache@v4
with:
driver: docker-container
driver-opts: network=host
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
restore-keys: |
${{ runner.os }}-buildx-frontend-test-
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v3
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
- name: Run docker compose
run: |
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
env:
NEXT_PUBLIC_PW_TEST: true
DOCKER_BUILDKIT: 1
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
- name: Move cache
run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
echo "Waiting for auth service to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
rm -rf /tmp/.buildx-cache
if [ -d "/tmp/.buildx-cache-new" ]; then
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
fi
- name: Set up Platform - Run migrations
- name: Wait for services to be ready
run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env:
NEXT_PUBLIC_PW_TEST: true
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Set up tests - Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
- name: Create E2E test data
run: |
echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
exit 1
}
# First try to run the script from inside the container
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
echo "✅ Found e2e_test_data.py in container, running it..."
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
exit 1
}
else
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
# Copy the script into the container and run it
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
echo "❌ Failed to copy script to container"
exit 1
}
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
exit 1
}
fi
# Dump auth.users + platform schema for cache (two separate dumps)
echo "Dumping database for cache..."
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
- name: Restore dependencies cache
uses: actions/cache@v4
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Set up tests - Install dependencies
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium'
- name: Install Browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
@@ -287,7 +257,7 @@ jobs:
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.resolved.yml logs
run: docker compose -f ../docker-compose.yml logs
integration_test:
runs-on: ubuntu-latest
@@ -295,19 +265,26 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Set up Node
uses: actions/setup-node@v6
- name: Restore dependencies cache
uses: actions/cache@v4
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile

View File

@@ -29,10 +29,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -44,7 +44,7 @@ jobs:
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
@@ -56,19 +56,19 @@ jobs:
run: pnpm install --frozen-lockfile
types:
runs-on: big-boi
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -85,10 +85,10 @@ jobs:
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}

View File

@@ -1,39 +0,0 @@
name: PR Overlap Detection
on:
pull_request:
types: [opened, synchronize, reopened]
branches:
- dev
- master
permissions:
contents: read
pull-requests: write
jobs:
check-overlaps:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # Need full history for merge testing
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Configure git
run: |
git config user.email "github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"
- name: Run overlap detection
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Always succeed - this check informs contributors, it shouldn't block merging
continue-on-error: true
run: |
python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}

View File

@@ -11,7 +11,7 @@ jobs:
steps:
# - name: Wait some time for all actions to start
# run: sleep 30
- uses: actions/checkout@v6
- uses: actions/checkout@v4
# with:
# fetch-depth: 0
- name: Set up Python

View File

@@ -1,195 +0,0 @@
#!/usr/bin/env python3
"""
Add cache configuration to a resolved docker-compose file for all services
that have a build key, and ensure image names match what docker compose expects.
"""
import argparse
import yaml
DEFAULT_BRANCH = "dev"
CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]
def main():
parser = argparse.ArgumentParser(
description="Add cache config to a resolved compose file"
)
parser.add_argument(
"--source",
required=True,
help="Source compose file to read (should be output of `docker compose config`)",
)
parser.add_argument(
"--cache-from",
default="type=gha",
help="Cache source configuration",
)
parser.add_argument(
"--cache-to",
default="type=gha,mode=max",
help="Cache destination configuration",
)
for component in CACHE_BUILDS_FOR_COMPONENTS:
parser.add_argument(
f"--{component}-hash",
default="",
help=f"Hash for {component} cache scope (e.g., from hashFiles())",
)
parser.add_argument(
"--git-ref",
default="",
help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
)
args = parser.parse_args()
# Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
git_ref_scope = ""
if args.git_ref:
git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")
with open(args.source, "r") as f:
compose = yaml.safe_load(f)
# Get project name from compose file or default
project_name = compose.get("name", "autogpt_platform")
def get_image_name(dockerfile: str, target: str) -> str:
"""Generate image name based on Dockerfile folder and build target."""
dockerfile_parts = dockerfile.replace("\\", "/").split("/")
if len(dockerfile_parts) >= 2:
folder_name = dockerfile_parts[-2] # e.g., "backend" or "frontend"
else:
folder_name = "app"
return f"{project_name}-{folder_name}:{target}"
def get_build_key(dockerfile: str, target: str) -> str:
"""Generate a unique key for a Dockerfile+target combination."""
return f"{dockerfile}:{target}"
def get_component(dockerfile: str) -> str | None:
"""Get component name (frontend/backend) from dockerfile path."""
for component in CACHE_BUILDS_FOR_COMPONENTS:
if component in dockerfile:
return component
return None
# First pass: collect all services with build configs and identify duplicates
# Track which (dockerfile, target) combinations we've seen
build_key_to_first_service: dict[str, str] = {}
services_to_build: list[str] = []
services_to_dedupe: list[str] = []
for service_name, service_config in compose.get("services", {}).items():
if "build" not in service_config:
continue
build_config = service_config["build"]
dockerfile = build_config.get("dockerfile", "Dockerfile")
target = build_config.get("target", "default")
build_key = get_build_key(dockerfile, target)
if build_key not in build_key_to_first_service:
# First service with this build config - it will do the actual build
build_key_to_first_service[build_key] = service_name
services_to_build.append(service_name)
else:
# Duplicate - will just use the image from the first service
services_to_dedupe.append(service_name)
# Second pass: configure builds and deduplicate
modified_services = []
for service_name, service_config in compose.get("services", {}).items():
if "build" not in service_config:
continue
build_config = service_config["build"]
dockerfile = build_config.get("dockerfile", "Dockerfile")
target = build_config.get("target", "latest")
image_name = get_image_name(dockerfile, target)
# Set image name for all services (needed for both builders and deduped)
service_config["image"] = image_name
if service_name in services_to_dedupe:
# Remove build config - this service will use the pre-built image
del service_config["build"]
continue
# This service will do the actual build - add cache config
cache_from_list = []
cache_to_list = []
component = get_component(dockerfile)
if not component:
# Skip services that don't clearly match frontend/backend
continue
# Get the hash for this component
component_hash = getattr(args, f"{component}_hash")
# Scope format: platform-{component}-{target}-{hash|ref}
# Example: platform-backend-server-abc123
if "type=gha" in args.cache_from:
# 1. Primary: exact hash match (most specific)
if component_hash:
hash_scope = f"platform-{component}-{target}-{component_hash}"
cache_from_list.append(f"{args.cache_from},scope={hash_scope}")
# 2. Fallback: branch-based cache
if git_ref_scope:
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
cache_from_list.append(f"{args.cache_from},scope={ref_scope}")
# 3. Fallback: dev branch cache (for PRs/feature branches)
if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
cache_from_list.append(f"{args.cache_from},scope={master_scope}")
if "type=gha" in args.cache_to:
# Write to both hash-based and branch-based scopes
if component_hash:
hash_scope = f"platform-{component}-{target}-{component_hash}"
cache_to_list.append(f"{args.cache_to},scope={hash_scope}")
if git_ref_scope:
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
cache_to_list.append(f"{args.cache_to},scope={ref_scope}")
# Ensure we have at least one cache source/target
if not cache_from_list:
cache_from_list.append(args.cache_from)
if not cache_to_list:
cache_to_list.append(args.cache_to)
build_config["cache_from"] = cache_from_list
build_config["cache_to"] = cache_to_list
modified_services.append(service_name)
# Write back to the same file
with open(args.source, "w") as f:
yaml.dump(compose, f, default_flow_style=False, sort_keys=False)
print(f"Added cache config to {len(modified_services)} services in {args.source}:")
for svc in modified_services:
svc_config = compose["services"][svc]
build_cfg = svc_config.get("build", {})
cache_from_list = build_cfg.get("cache_from", ["none"])
cache_to_list = build_cfg.get("cache_to", ["none"])
print(f" - {svc}")
print(f" image: {svc_config.get('image', 'N/A')}")
print(f" cache_from: {cache_from_list}")
print(f" cache_to: {cache_to_list}")
if services_to_dedupe:
print(
f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
)
for svc in services_to_dedupe:
print(f" - {svc} -> {compose['services'][svc].get('image', 'N/A')}")
if __name__ == "__main__":
main()

1
.gitignore vendored
View File

@@ -180,4 +180,3 @@ autogpt_platform/backend/settings.py
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
.next

View File

@@ -45,11 +45,6 @@ AutoGPT Platform is a monorepo containing:
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.

File diff suppressed because it is too large Load Diff

View File

@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
colorama = "^0.4.6"
cryptography = "^46.0"
cryptography = "^45.0"
expiringdict = "^1.2.2"
fastapi = "^0.128.7"
google-cloud-logging = "^3.13.0"
launchdarkly-server-sdk = "^9.15.0"
pydantic = "^2.12.5"
pydantic-settings = "^2.12.0"
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
fastapi = "^0.116.1"
google-cloud-logging = "^3.12.1"
launchdarkly-server-sdk = "^9.12.0"
pydantic = "^2.11.7"
pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
redis = "^6.2.0"
supabase = "^2.28.0"
uvicorn = "^0.40.0"
supabase = "^2.16.0"
uvicorn = "^0.35.0"
[tool.poetry.group.dev.dependencies]
pyright = "^1.1.408"
pyright = "^1.1.404"
pytest = "^8.4.1"
pytest-asyncio = "^1.3.0"
pytest-mock = "^3.15.1"
pytest-cov = "^7.0.0"
ruff = "^0.15.0"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
pytest-cov = "^6.2.1"
ruff = "^0.12.11"
[build-system]
requires = ["poetry-core"]

View File

@@ -104,12 +104,6 @@ TWITTER_CLIENT_SECRET=
# Make a new workspace for your OAuth APP -- trust me
# https://linear.app/settings/api/applications/new
# Callback URL: http://localhost:3000/auth/integrations/oauth_callback
LINEAR_API_KEY=
# Linear project and team IDs for the feature request tracker.
# Find these in your Linear workspace URL: linear.app/<workspace>/project/<project-id>
# and in team settings. Used by the chat copilot to file and search feature requests.
LINEAR_FEATURE_REQUEST_PROJECT_ID=
LINEAR_FEATURE_REQUEST_TEAM_ID=
LINEAR_CLIENT_ID=
LINEAR_CLIENT_SECRET=
@@ -158,7 +152,6 @@ REPLICATE_API_KEY=
REVID_API_KEY=
SCREENSHOTONE_API_KEY=
UNREAL_SPEECH_API_KEY=
ELEVENLABS_API_KEY=
# Data & Search Services
E2B_API_KEY=

View File

@@ -19,6 +19,3 @@ load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
migrations/*/rollback*.sql
# Workspace files
workspaces/

View File

@@ -1,5 +1,3 @@
# ============================ DEPENDENCY BUILDER ============================ #
FROM debian:13-slim AS builder
# Set environment variables
@@ -53,62 +51,25 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
# =============================== DB MIGRATOR =============================== #
# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
FROM debian:13-slim AS migrate
WORKDIR /app/autogpt_platform/backend
ENV DEBIAN_FRONTEND=noninteractive
# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
python3-pip \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Copy Node.js from builder (needed for Prisma CLI)
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
# Copy Prisma binaries
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
# Install prisma-client-py directly (much smaller than copying full venv)
RUN pip3 install prisma>=0.15.0 --break-system-packages
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
COPY autogpt_platform/backend/migrations ./migrations
# ============================== BACKEND SERVER ============================== #
FROM debian:13-slim AS server
FROM debian:13-slim AS server_dependencies
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
ENV POETRY_HOME=/opt/poetry \
POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=true \
POETRY_VIRTUALENVS_IN_PROJECT=true \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool.
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
bubblewrap \
&& rm -rf /var/lib/apt/lists/*
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
@@ -118,25 +79,30 @@ COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)
# The .venv includes the generated Prisma client
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
# Copy dependency files + autogpt_libs (path dependency)
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
RUN mkdir -p /app/autogpt_platform/autogpt_libs
RUN mkdir -p /app/autogpt_platform/backend
# Copy backend code + docs (for Copilot docs search)
COPY autogpt_platform/backend ./
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
WORKDIR /app/autogpt_platform/backend
FROM server_dependencies AS migrate
# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
COPY docs /app/docs
# Install the project package to create entry point scripts in .venv/bin/
# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
poetry install --no-ansi --only-root
RUN poetry install --no-ansi --only-root
ENV PORT=8000
CMD ["rest"]
CMD ["poetry", "run", "rest"]

View File

@@ -1,9 +1,4 @@
"""Common test fixtures for server tests.
Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
setup_test_user, and setup_admin_user are defined in the parent conftest.py
(backend/conftest.py) and are available here automatically.
"""
"""Common test fixtures for server tests."""
import pytest
from pytest_snapshot.plugin import Snapshot
@@ -16,6 +11,54 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
return snapshot
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
@pytest.fixture
async def setup_test_user(test_user_id):
"""Create test user in database before tests."""
from backend.data.user import get_or_create_user
# Create the test user in the database using JWT token format
user_data = {
"sub": test_user_id,
"email": "test@example.com",
"user_metadata": {"name": "Test User"},
}
await get_or_create_user(user_data)
return test_user_id
@pytest.fixture
async def setup_admin_user(admin_user_id):
"""Create admin user in database before tests."""
from backend.data.user import get_or_create_user
# Create the admin user in the database using JWT token format
user_data = {
"sub": admin_user_id,
"email": "test-admin@example.com",
"user_metadata": {"name": "Test Admin"},
}
await get_or_create_user(user_data)
return admin_user_id
@pytest.fixture
def mock_jwt_user(test_user_id):
"""Provide mock JWT payload for regular user testing."""

View File

@@ -10,7 +10,7 @@ from typing_extensions import TypedDict
import backend.api.features.store.cache as store_cache
import backend.api.features.store.model as store_model
import backend.blocks
import backend.data.block
from backend.api.external.middleware import require_permission
from backend.data import execution as execution_db
from backend.data import graph as graph_db
@@ -67,7 +67,7 @@ async def get_user_info(
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
)
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.blocks.get_blocks().values()]
blocks = [block() for block in backend.data.block.get_blocks().values()]
return [b.to_dict() for b in blocks if not b.disabled]
@@ -83,7 +83,7 @@ async def execute_graph_block(
require_permission(APIKeyPermission.EXECUTE_BLOCK)
),
) -> CompletedBlockOutput:
obj = backend.blocks.get_block(block_id)
obj = backend.data.block.get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
if obj.disabled:

View File

@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.copilot.model import ChatSession
from backend.copilot.tools import find_agent_tool, run_agent_tool
from backend.copilot.tools.models import ToolResponseBase
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
from backend.api.features.chat.tools.models import ToolResponseBase
from backend.data.auth.base import APIAuthorizationInfo
logger = logging.getLogger(__name__)

View File

@@ -10,15 +10,10 @@ import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
import backend.data.block
from backend.blocks import load_all_blocks
from backend.blocks._base import (
AnyBlockSchema,
BlockCategory,
BlockInfo,
BlockSchema,
BlockType,
)
from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
from backend.data.db import query_raw_with_schema
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
@@ -27,7 +22,7 @@ from backend.util.models import Pagination
from .model import (
BlockCategoryResponse,
BlockResponse,
BlockTypeFilter,
BlockType,
CountResponse,
FilterType,
Provider,
@@ -93,7 +88,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
def get_blocks(
*,
category: str | None = None,
type: BlockTypeFilter | None = None,
type: BlockType | None = None,
provider: ProviderName | None = None,
page: int = 1,
page_size: int = 50,
@@ -674,9 +669,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
if block.disabled or block.block_type in (
BlockType.INPUT,
BlockType.OUTPUT,
BlockType.AGENT,
backend.data.block.BlockType.INPUT,
backend.data.block.BlockType.OUTPUT,
backend.data.block.BlockType.AGENT,
):
continue
# Find the execution count for this block

View File

@@ -4,7 +4,7 @@ from pydantic import BaseModel
import backend.api.features.library.model as library_model
import backend.api.features.store.model as store_model
from backend.blocks._base import BlockInfo
from backend.data.block import BlockInfo
from backend.integrations.providers import ProviderName
from backend.util.models import Pagination
@@ -15,7 +15,7 @@ FilterType = Literal[
"my_agents",
]
BlockTypeFilter = Literal["all", "input", "action", "output"]
BlockType = Literal["all", "input", "action", "output"]
class SearchEntry(BaseModel):

View File

@@ -88,7 +88,7 @@ async def get_block_categories(
)
async def get_blocks(
category: Annotated[str | None, fastapi.Query()] = None,
type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
page_size: Annotated[int, fastapi.Query()] = 50,

View File

@@ -0,0 +1,96 @@
"""Configuration management for chat system."""
import os
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
# OpenAI API Configuration
model: str = Field(
default="anthropic/claude-opus-4.5", description="Default model to use"
)
title_model: str = Field(
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default="https://openrouter.ai/api/v1",
description="Base URL for API (e.g., for OpenRouter)",
)
# Session TTL Configuration - 12 hours
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
)
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Langfuse Prompt Management Configuration
# Note: Langfuse credentials are in Settings().secrets (settings.py)
langfuse_prompt_name: str = Field(
default="CoPilot Prompt",
description="Name of the prompt in Langfuse to fetch",
)
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):
"""Get API key from environment if not provided."""
if v is None:
# Try to get from environment variables
# First check for CHAT_API_KEY (Pydantic prefix)
v = os.getenv("CHAT_API_KEY")
if not v:
# Fall back to OPEN_ROUTER_API_KEY
v = os.getenv("OPEN_ROUTER_API_KEY")
if not v:
# Fall back to OPENAI_API_KEY
v = os.getenv("OPENAI_API_KEY")
return v
@field_validator("base_url", mode="before")
@classmethod
def get_base_url(cls, v):
"""Get base URL from environment if not provided."""
if v is None:
# Check for OpenRouter or custom base URL
v = os.getenv("CHAT_BASE_URL")
if not v:
v = os.getenv("OPENROUTER_BASE_URL")
if not v:
v = os.getenv("OPENAI_BASE_URL")
if not v:
v = "https://openrouter.ai/api/v1"
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",
"onboarding": "prompts/onboarding_system.md",
}
class Config:
"""Pydantic config."""
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore" # Ignore extra environment variables

View File

@@ -14,27 +14,29 @@ from prisma.types import (
ChatSessionWhereInput,
)
from backend.data import db
from backend.data.db import transaction
from backend.util.json import SafeJson
from .model import ChatMessage, ChatSession, ChatSessionInfo
logger = logging.getLogger(__name__)
async def get_chat_session(session_id: str) -> ChatSession | None:
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
"""Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
include={"Messages": {"order_by": {"sequence": "asc"}}},
include={"Messages": True},
)
return ChatSession.from_db(session) if session else None
if session and session.Messages:
# Sort messages by sequence in Python - Prisma Python client doesn't support
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
session.Messages.sort(key=lambda m: m.sequence)
return session
async def create_chat_session(
session_id: str,
user_id: str,
) -> ChatSessionInfo:
) -> PrismaChatSession:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
id=session_id,
@@ -43,8 +45,10 @@ async def create_chat_session(
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
prisma_session = await PrismaChatSession.prisma().create(data=data)
return ChatSessionInfo.from_db(prisma_session)
return await PrismaChatSession.prisma().create(
data=data,
include={"Messages": True},
)
async def update_chat_session(
@@ -55,7 +59,7 @@ async def update_chat_session(
total_prompt_tokens: int | None = None,
total_completion_tokens: int | None = None,
title: str | None = None,
) -> ChatSession | None:
) -> PrismaChatSession | None:
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
@@ -75,9 +79,12 @@ async def update_chat_session(
session = await PrismaChatSession.prisma().update(
where={"id": session_id},
data=data,
include={"Messages": {"order_by": {"sequence": "asc"}}},
include={"Messages": True},
)
return ChatSession.from_db(session) if session else None
if session and session.Messages:
# Sort in Python - Prisma Python doesn't support order_by in include clauses
session.Messages.sort(key=lambda m: m.sequence)
return session
async def add_chat_message(
@@ -90,7 +97,7 @@ async def add_chat_message(
refusal: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
function_call: dict[str, Any] | None = None,
) -> ChatMessage:
) -> PrismaChatMessage:
"""Add a message to a chat session."""
# Build input dict dynamically rather than using ChatMessageCreateInput directly
# because Prisma's TypedDict validation rejects optional fields set to None.
@@ -125,14 +132,14 @@ async def add_chat_message(
),
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
)
return ChatMessage.from_db(message)
return message
async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> list[ChatMessage]:
) -> list[PrismaChatMessage]:
"""Add multiple messages to a chat session in a batch.
Uses a transaction for atomicity - if any message creation fails,
@@ -143,7 +150,7 @@ async def add_chat_messages_batch(
created_messages = []
async with db.transaction() as tx:
async with transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
@@ -183,22 +190,21 @@ async def add_chat_messages_batch(
data={"updatedAt": datetime.now(UTC)},
)
return [ChatMessage.from_db(m) for m in created_messages]
return created_messages
async def get_user_chat_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> list[ChatSessionInfo]:
) -> list[PrismaChatSession]:
"""Get chat sessions for a user, ordered by most recent."""
prisma_sessions = await PrismaChatSession.prisma().find_many(
return await PrismaChatSession.prisma().find_many(
where={"userId": user_id},
order={"updatedAt": "desc"},
take=limit,
skip=offset,
)
return [ChatSessionInfo.from_db(s) for s in prisma_sessions]
async def get_user_session_count(user_id: str) -> int:

View File

@@ -2,7 +2,7 @@ import asyncio
import logging
import uuid
from datetime import UTC, datetime
from typing import Any, Self, cast
from typing import Any
from weakref import WeakValueDictionary
from openai.types.chat import (
@@ -23,17 +23,26 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel
from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError
from . import db as chat_db
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
# Redis cache key prefix for chat sessions
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
@@ -43,7 +52,28 @@ def _get_session_cache_key(session_id: str) -> str:
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
# ===================== Chat data models ===================== #
# Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
class ChatMessage(BaseModel):
@@ -55,19 +85,6 @@ class ChatMessage(BaseModel):
tool_calls: list[dict] | None = None
function_call: dict | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
"""Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
return ChatMessage(
role=prisma_message.role,
content=prisma_message.content,
name=prisma_message.name,
tool_call_id=prisma_message.toolCallId,
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
)
class Usage(BaseModel):
prompt_tokens: int
@@ -75,10 +92,11 @@ class Usage(BaseModel):
total_tokens: int
class ChatSessionInfo(BaseModel):
class ChatSession(BaseModel):
session_id: str
user_id: str
title: str | None = None
messages: list[ChatMessage]
usage: list[Usage]
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
started_at: datetime
@@ -86,9 +104,40 @@ class ChatSessionInfo(BaseModel):
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
"""Convert Prisma ChatSession to Pydantic ChatSession."""
@staticmethod
def new(user_id: str) -> "ChatSession":
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
@staticmethod
def from_db(
prisma_session: PrismaChatSession,
prisma_messages: list[PrismaChatMessage] | None = None,
) -> "ChatSession":
"""Convert Prisma models to Pydantic ChatSession."""
messages = []
if prisma_messages:
for msg in prisma_messages:
messages.append(
ChatMessage(
role=msg.role,
content=msg.content,
name=msg.name,
tool_call_id=msg.toolCallId,
refusal=msg.refusal,
tool_calls=_parse_json_field(msg.toolCalls),
function_call=_parse_json_field(msg.functionCall),
)
)
# Parse JSON fields from Prisma
credentials = _parse_json_field(prisma_session.credentials, default={})
successful_agent_runs = _parse_json_field(
@@ -110,10 +159,11 @@ class ChatSessionInfo(BaseModel):
)
)
return cls(
return ChatSession(
session_id=prisma_session.id,
user_id=prisma_session.userId,
title=prisma_session.title,
messages=messages,
usage=usage,
credentials=credentials,
started_at=prisma_session.createdAt,
@@ -122,56 +172,6 @@ class ChatSessionInfo(BaseModel):
successful_agent_schedules=successful_agent_schedules,
)
class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
@classmethod
def new(cls, user_id: str) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
"""Convert Prisma ChatSession to Pydantic ChatSession."""
if prisma_session.Messages is None:
raise ValueError(
f"Prisma session {prisma_session.id} is missing Messages relation"
)
return cls(
**ChatSessionInfo.from_db(prisma_session).model_dump(),
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
)
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
"""Attach a tool_call to the current turn's assistant message.
Searches backwards for the most recent assistant message (stopping at
any user message boundary). If found, appends the tool_call to it.
Otherwise creates a new assistant message with the tool_call.
"""
for msg in reversed(self.messages):
if msg.role == "user":
break
if msg.role == "assistant":
if not msg.tool_calls:
msg.tool_calls = []
msg.tool_calls.append(tool_call)
return
self.messages.append(
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
)
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
messages = []
for message in self.messages:
@@ -258,72 +258,43 @@ class ChatSession(ChatSessionInfo):
name=message.name or "",
)
)
return self._merge_consecutive_assistant_messages(messages)
@staticmethod
def _merge_consecutive_assistant_messages(
messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
"""Merge consecutive assistant messages into single messages.
Long-running tool flows can create split assistant messages: one with
text content and another with tool_calls. Anthropic's API requires
tool_result blocks to reference a tool_use in the immediately preceding
assistant message, so these splits cause 400 errors via OpenRouter.
"""
if len(messages) < 2:
return messages
result: list[ChatCompletionMessageParam] = [messages[0]]
for msg in messages[1:]:
prev = result[-1]
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
result.append(msg)
continue
prev = cast(ChatCompletionAssistantMessageParam, prev)
curr = cast(ChatCompletionAssistantMessageParam, msg)
curr_content = curr.get("content") or ""
if curr_content:
prev_content = prev.get("content") or ""
prev["content"] = (
f"{prev_content}\n{curr_content}" if prev_content else curr_content
)
curr_tool_calls = curr.get("tool_calls")
if curr_tool_calls:
prev_tool_calls = prev.get("tool_calls")
prev["tool_calls"] = (
list(prev_tool_calls) + list(curr_tool_calls)
if prev_tool_calls
else list(curr_tool_calls)
)
return result
return messages
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
# ================ Chat cache + DB operations ================ #
# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
# connected directly.
async def cache_chat_session(session: ChatSession) -> None:
"""Cache a chat session in Redis (without persisting to the database)."""
async def _cache_session(session: ChatSession) -> None:
"""Cache a chat session in Redis."""
redis_key = _get_session_cache_key(session.session_id)
async_redis = await get_redis_async()
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
async def cache_chat_session(session: ChatSession) -> None:
"""Cache a chat session without persisting to the database."""
await _cache_session(session)
async def invalidate_session_cache(session_id: str) -> None:
"""Invalidate a chat session from Redis cache.
@@ -339,6 +310,80 @@ async def invalidate_session_cache(session_id: str) -> None:
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)
if not prisma_session:
return None
messages = prisma_session.Messages
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_db(prisma_session, messages)
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
# Check if session exists in DB
existing = await chat_db.get_chat_session(session.session_id)
if not existing:
# Create new session
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await chat_db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def get_chat_session(
session_id: str,
user_id: str | None = None,
@@ -370,7 +415,7 @@ async def get_chat_session(
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database
logger.debug(f"Session {session_id} not in cache, checking database")
logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
@@ -386,7 +431,7 @@ async def get_chat_session(
# Cache the session from DB
try:
await cache_chat_session(session)
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -394,45 +439,9 @@ async def get_chat_session(
return session
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
session = await chat_db().get_chat_session(session_id)
if not session:
return None
logger.info(
f"Loaded session {session_id} from DB: "
f"has_messages={bool(session.messages)}, "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
async def upsert_chat_session(session: ChatSession) -> ChatSession:
async def upsert_chat_session(
session: ChatSession,
) -> ChatSession:
"""Update a chat session in both cache and database.
Uses session-level locking to prevent race conditions when concurrent
@@ -450,7 +459,7 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
async with lock:
# Get existing message count from DB for incremental saves
existing_message_count = await chat_db().get_chat_session_message_count(
existing_message_count = await chat_db.get_chat_session_message_count(
session.session_id
)
@@ -467,7 +476,7 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
# Save to cache (best-effort, even if DB failed)
try:
await cache_chat_session(session)
await _cache_session(session)
except Exception as e:
# If DB succeeded but cache failed, raise cache error
if db_error is None:
@@ -488,99 +497,6 @@ async def upsert_chat_session(session: ChatSession) -> ChatSession:
return session
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
db = chat_db()
# Check if session exists in DB
existing = await db.get_chat_session(session.session_id)
if not existing:
# Create new session
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
async with lock:
session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db().get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
raise DatabaseError(
f"Failed to persist message to session {session_id}"
) from e
try:
await cache_chat_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
return session
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
@@ -593,7 +509,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
# Create in database first - fail fast if this fails
try:
await chat_db().create_chat_session(
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=user_id,
)
@@ -605,7 +521,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
# Cache the session (best-effort optimization, DB is source of truth)
try:
await cache_chat_session(session)
await _cache_session(session)
except Exception as e:
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
@@ -616,16 +532,20 @@ async def get_user_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> tuple[list[ChatSessionInfo], int]:
) -> tuple[list[ChatSession], int]:
"""Get chat sessions for a user from the database with total count.
Returns:
A tuple of (sessions, total_count) where total_count is the overall
number of sessions for the user (not just the current page).
"""
db = chat_db()
sessions = await db.get_user_chat_sessions(user_id, limit, offset)
total_count = await db.get_user_session_count(user_id)
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
total_count = await chat_db.get_user_session_count(user_id)
sessions = []
for prisma_session in prisma_sessions:
# Convert without messages for listing (lighter weight)
sessions.append(ChatSession.from_db(prisma_session, None))
return sessions, total_count
@@ -643,7 +563,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
"""
# Delete from database first (with optional user_id validation)
# This confirms ownership before invalidating cache
deleted = await chat_db().delete_chat_session(session_id, user_id)
deleted = await chat_db.delete_chat_session(session_id, user_id)
if not deleted:
return False
@@ -678,52 +598,20 @@ async def update_session_title(session_id: str, title: str) -> bool:
True if updated successfully, False otherwise.
"""
try:
result = await chat_db().update_chat_session(session_id=session_id, title=title)
result = await chat_db.update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
return False
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
# Invalidate cache so next fetch gets updated title
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await cache_chat_session(cached)
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
return True
except Exception as e:
logger.error(f"Failed to update title for session {session_id}: {e}")
return False
# ==================== Chat session locks ==================== #
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
This was originally added to solve the specific problem of race conditions between
the session title thread and the conversation thread, which always occurs on the
same instance as we prevent rapid request sends on the frontend.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock

View File

@@ -0,0 +1,119 @@
import pytest
from .model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
upsert_chat_session,
)
messages = [
ChatMessage(content="Hello, how are you?", role="user"),
ChatMessage(
content="I'm fine, thank you!",
role="assistant",
tool_calls=[
{
"id": "t123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "New York"}',
},
}
],
),
ChatMessage(
content="I'm using the tool to get the weather",
role="tool",
tool_call_id="t123",
),
]
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
s = ChatSession.new(user_id="abc123")
s.messages = messages
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
serialized = s.model_dump_json()
s2 = ChatSession.model_validate_json(serialized)
assert s2.model_dump() == s.model_dump()
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s2 = await get_chat_session(
session_id=s.session_id,
user_id=s.user_id,
)
assert s2 == s
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s2 = await get_chat_session(s.session_id, "different_user_id")
assert s2 is None
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_db_storage(setup_test_user, test_user_id):
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
s = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
# Load from DB (cache was cleared)
s2 = await get_chat_session(
session_id=s.session_id,
user_id=s.user_id,
)
assert s2 is not None, "Session not found after loading from DB"
assert len(s2.messages) == len(
s.messages
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
# Verify all roles are present
roles = [m.role for m in s2.messages]
assert "user" in roles, f"User message missing. Roles found: {roles}"
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
# Verify message content
for orig, loaded in zip(s.messages, s2.messages):
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
assert (
orig.content == loaded.content
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
if orig.tool_calls:
assert (
loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls)

View File

@@ -10,8 +10,6 @@ from typing import Any
from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol."""
@@ -20,10 +18,6 @@ class ResponseType(str, Enum):
START = "start"
FINISH = "finish"
# Step lifecycle (one LLM API call within a message)
START_STEP = "start-step"
FINISH_STEP = "finish-step"
# Text streaming
TEXT_START = "text-start"
TEXT_DELTA = "text-delta"
@@ -58,20 +52,6 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
taskId: str | None = Field(
default=None,
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
}
return f"data: {json.dumps(data)}\n\n"
class StreamFinish(StreamBaseResponse):
@@ -80,26 +60,6 @@ class StreamFinish(StreamBaseResponse):
type: ResponseType = ResponseType.FINISH
class StreamStartStep(StreamBaseResponse):
"""Start of a step (one LLM API call within a message).
The AI SDK uses this to add a step-start boundary to message.parts,
enabling visual separation between multiple LLM calls in a single message.
"""
type: ResponseType = ResponseType.START_STEP
class StreamFinishStep(StreamBaseResponse):
"""End of a step (one LLM API call within a message).
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
so the next LLM call in a tool-call continuation starts with clean state.
"""
type: ResponseType = ResponseType.FINISH_STEP
# ========== Text Streaming ==========
@@ -153,7 +113,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
toolCallId: str = Field(..., description="Tool call ID this responds to")
output: str | dict[str, Any] = Field(..., description="Tool execution output")
# Keep these for internal backend use
# Additional fields for internal use (not part of AI SDK spec but useful)
toolName: str | None = Field(
default=None, description="Name of the tool that was executed"
)
@@ -161,17 +121,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded"
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,
"output": self.output,
}
return f"data: {json.dumps(data)}\n\n"
# ========== Other ==========
@@ -195,18 +144,6 @@ class StreamError(StreamBaseResponse):
default=None, description="Additional error details"
)
def to_sse(self) -> str:
"""Convert to SSE format, only emitting fields required by AI SDK protocol.
The AI SDK uses z.strictObject({type, errorText}) which rejects
any extra fields like `code` or `details`.
"""
data = {
"type": self.type.value,
"errorText": self.errorText,
}
return f"data: {json_dumps(data)}\n\n"
class StreamHeartbeat(StreamBaseResponse):
"""Heartbeat to keep SSE connection alive during long-running operations.

View File

@@ -1,60 +1,20 @@
"""Chat API routes for chat session management and streaming via SSE."""
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.completion_handler import (
process_operation_failure,
process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_copilot_task
from backend.copilot.model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
delete_chat_session,
get_chat_session,
get_user_sessions,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
AgentPreviewResponse,
AgentSavedResponse,
AgentsFoundResponse,
BlockDetailsResponse,
BlockListResponse,
BlockOutputResponse,
ClarificationNeededResponse,
DocPageResponse,
DocSearchResultsResponse,
ErrorResponse,
ExecutionStartedResponse,
InputValidationErrorResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from . import service as chat_service
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
config = ChatConfig()
@@ -95,15 +55,6 @@ class CreateSessionResponse(BaseModel):
user_id: str | None
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
task_id: str
last_message_id: str # Redis Stream message ID for resumption
operation_id: str # Operation ID for completion tracking
tool_name: str # Name of the tool being executed
class SessionDetailResponse(BaseModel):
"""Response model providing complete details for a chat session, including messages."""
@@ -112,7 +63,6 @@ class SessionDetailResponse(BaseModel):
updated_at: str
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
class SessionSummaryResponse(BaseModel):
@@ -131,14 +81,6 @@ class ListSessionsResponse(BaseModel):
total: int
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -213,43 +155,6 @@ async def create_session(
)
@router.delete(
"/sessions/{session_id}",
dependencies=[Security(auth.requires_user)],
status_code=204,
responses={404: {"description": "Session not found or access denied"}},
)
async def delete_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
"""
Delete a chat session.
Permanently removes a chat session and all its messages.
Only the owner can delete their sessions.
Args:
session_id: The session ID to delete.
user_id: The authenticated user's ID.
Returns:
204 No Content on success.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
deleted = await delete_chat_session(session_id, user_id)
if not deleted:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return Response(status_code=204)
@router.get(
"/sessions/{session_id}",
)
@@ -261,14 +166,13 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns the task_id for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
SessionDetailResponse: Details for the requested session, or None if not found.
"""
session = await get_chat_session(session_id, user_id)
@@ -276,32 +180,11 @@ async def get_session(
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
active_stream_info = None
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
messages = messages[:-1]
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
active_stream_info = ActiveStreamInfo(
task_id=active_task.task_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
)
return SessionDetailResponse(
id=session.session_id,
@@ -309,7 +192,6 @@ async def get_session(
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
)
@@ -329,225 +211,49 @@ async def stream_chat_post(
- Tool call UI elements (if invoked)
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
import time
session = await _validate_and_get_session(session_id, user_id)
stream_start_time = time.perf_counter()
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
}
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
task_create_start = time.perf_counter()
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_name="chat",
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
}
},
)
await enqueue_copilot_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
operation_id=operation_id,
message=request.message,
is_user_message=request.is_user_message,
context=request.context,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module
event_gen_start = time_module.perf_counter()
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
)
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
# Read from the subscriber queue and yield to SSE
logger.info(
"[TIMING] Starting to read from subscriber_queue",
extra={"json_fields": log_meta},
)
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
chunks_yielded += 1
if not first_chunk_yielded:
first_chunk_yielded = True
elapsed = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
f"type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"elapsed_ms": elapsed * 1000,
}
},
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
f"n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"total_time_ms": total_time * 1000,
}
},
)
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
logger.info(
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"reason": "client_disconnect",
}
},
)
pass # Client disconnected - background task continues
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
extra={
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
# Surface error to frontend so it doesn't appear stuck
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
finally:
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time * 1000,
"chunks_yielded": chunks_yielded,
}
},
)
yield "data: [DONE]\n\n"
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -564,90 +270,63 @@ async def stream_chat_post(
@router.get(
"/sessions/{session_id}/stream",
)
async def resume_session_stream(
async def stream_chat_get(
session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
):
"""
Resume an active stream for a session.
Stream chat responses for a session (GET - legacy endpoint).
Called by the AI SDK's ``useChat(resume: true)`` on page load.
Checks for an active (in-progress) task on the session and either replays
the full SSE stream or returns 204 No Content if nothing is running.
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
- Tool call UI elements (if invoked)
- Tool execution results
Args:
session_id: The chat session identifier.
session_id: The chat session identifier to associate with the streamed messages.
message: The user's new message to process.
user_id: Optional authenticated user ID.
is_user_message: Whether the message is a user message.
Returns:
StreamingResponse (SSE) when an active stream exists,
or 204 No Content when there is nothing to resume.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
active_task, _last_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_task:
return Response(status_code=204)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
user_id=user_id,
last_message_id="0-0", # Full replay so useChat rebuilds the message
)
if subscriber_queue is None:
return Response(status_code=204)
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0
first_chunk_type: str | None = None
try:
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
if chunk_count < 3:
logger.info(
"Resume stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass
except Exception as e:
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
async for chunk in chat_service.stream_chat_completion(
session_id,
message,
is_user_message=is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
exc_info=True,
)
logger.info(
"Resume stream completed",
extra={
"session_id": session_id,
"n_chunks": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
yield "data: [DONE]\n\n"
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -655,8 +334,8 @@ async def resume_session_stream(
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
)
@@ -687,249 +366,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@router.get("/config/ttl", status_code=200)
async def get_ttl_config() -> dict:
"""
Get the stream TTL configuration.
Returns the Time-To-Live settings for chat streams, which determines
how long clients can reconnect to an active stream.
Returns:
dict: TTL configuration with seconds and milliseconds values.
"""
return {
"stream_ttl_seconds": config.stream_ttl,
"stream_ttl_ms": config.stream_ttl * 1000,
}
# ========== Health Check ==========
@@ -966,43 +402,3 @@ async def health_check() -> dict:
"service": "chat",
"version": "0.1.0",
}
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
ToolResponseUnion = (
AgentsFoundResponse
| NoResultsResponse
| AgentDetailsResponse
| SetupRequirementsResponse
| ExecutionStartedResponse
| NeedLoginResponse
| ErrorResponse
| InputValidationErrorResponse
| AgentOutputResponse
| UnderstandingUpdatedResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
| BlockListResponse
| BlockDetailsResponse
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)
@router.get(
"/schema/tool-responses",
response_model=ToolResponseUnion,
include_in_schema=True,
summary="[Dummy] Tool response type export for codegen",
description="This endpoint is not meant to be called. It exists solely to "
"expose tool response models in the OpenAPI schema for frontend codegen.",
)
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
"""Never called at runtime. Exists only so Orval generates TS types."""
raise HTTPException(status_code=501, detail="Schema-only endpoint")

View File

@@ -0,0 +1,82 @@
import logging
from os import getenv
import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import (
StreamError,
StreamFinish,
StreamTextDelta,
StreamToolOutputAvailable,
)
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
has_errors = False
has_ended = False
assistant_message = ""
async for chunk in chat_service.stream_chat_completion(
session.session_id, "Hello, how are you?", user_id=session.user_id
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
if isinstance(chunk, StreamFinish):
has_ended = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert assistant_message, "Assistant message is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
has_errors = False
has_ended = False
had_tool_calls = False
async for chunk in chat_service.stream_chat_completion(
session.session_id,
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
user_id=session.user_id,
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamFinish):
has_ended = True
if isinstance(chunk, StreamToolOutputAvailable):
had_tool_calls = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await get_chat_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"

View File

@@ -3,18 +3,14 @@ from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_tool_called
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
@@ -22,7 +18,6 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
@@ -31,7 +26,7 @@ from .workspace_files import (
)
if TYPE_CHECKING:
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.api.features.chat.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__)
@@ -39,7 +34,6 @@ logger = logging.getLogger(__name__)
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"create_agent": CreateAgentTool(),
"customize_agent": CustomizeAgentTool(),
"edit_agent": EditAgentTool(),
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),
@@ -47,17 +41,8 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"view_agent_output": AgentOutputTool(),
"check_operation_status": CheckOperationStatusTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),

View File

@@ -6,11 +6,11 @@ import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials

View File

@@ -3,9 +3,11 @@
import logging
from typing import Any
from backend.copilot.model import ChatSession
from backend.data.db_accessors import understanding_db
from backend.data.understanding import BusinessUnderstandingInput
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
@@ -97,9 +99,7 @@ and automations for the user's specific needs."""
]
# Upsert with merge
understanding = await understanding_db().upsert_business_understanding(
user_id, input_data
)
understanding = await upsert_business_understanding(user_id, input_data)
# Build current understanding summary (filter out empty values)
current_understanding = {

View File

@@ -8,7 +8,6 @@ from .core import (
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
@@ -20,7 +19,6 @@ from .core import (
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_for_generation,
graph_to_json,
json_to_graph,
save_agent_to_library,
search_marketplace_agents_for_generation,
@@ -38,7 +36,6 @@ __all__ = [
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
@@ -51,7 +48,6 @@ __all__ = [
"get_library_agent_by_id",
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"json_to_graph",
"save_agent_to_library",

View File

@@ -5,12 +5,20 @@ import re
import uuid
from typing import Any, NotRequired, TypedDict
from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.graph import (
Graph,
Link,
Node,
create_graph,
get_graph,
get_graph_all_versions,
get_store_listed_graphs,
)
from backend.util.exceptions import DatabaseError, NotFoundError
from .service import (
customize_template_external,
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
@@ -19,6 +27,8 @@ from .service import (
logger = logging.getLogger(__name__)
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
class ExecutionSummary(TypedDict):
"""Summary of a single execution for quality assessment."""
@@ -144,9 +154,8 @@ async def get_library_agent_by_id(
Returns:
LibraryAgentSummary if found, None otherwise
"""
db = library_db()
try:
agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return LibraryAgentSummary(
@@ -163,7 +172,7 @@ async def get_library_agent_by_id(
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
try:
agent = await db.get_library_agent(agent_id, user_id)
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return LibraryAgentSummary(
@@ -215,7 +224,7 @@ async def get_library_agents_for_generation(
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
try:
response = await library_db().list_library_agents(
response = await library_db.list_library_agents(
user_id=user_id,
search_term=search_query,
page=1,
@@ -272,7 +281,7 @@ async def search_marketplace_agents_for_generation(
List of LibraryAgentSummary with full input/output schemas
"""
try:
response = await store_db().get_store_agents(
response = await store_db.get_store_agents(
search_query=search_query,
page=1,
page_size=max_results,
@@ -286,7 +295,7 @@ async def search_marketplace_agents_for_generation(
return []
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
graphs = await graph_db().get_store_listed_graphs(graph_ids)
graphs = await get_store_listed_graphs(*graph_ids)
results: list[LibraryAgentSummary] = []
for agent in agents_with_graphs:
@@ -540,21 +549,15 @@ async def decompose_goal(
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams
completion notification)
task_id: Task ID for async processing (enables Redis Streams persistence
and SSE delivery)
Returns:
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -562,13 +565,8 @@ async def generate_agent(
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
dict(instructions), _to_dict_list(library_agents)
)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
@@ -659,6 +657,45 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
)
def _reassign_node_ids(graph: Graph) -> None:
"""Reassign all node and link IDs to new UUIDs.
This is needed when creating a new version to avoid unique constraint violations.
"""
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]
for link in graph.links:
link.id = str(uuid.uuid4())
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
"""Populate user_id in AgentExecutorBlock nodes.
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
This function fills in the actual user_id so sub-agents run with correct permissions.
Args:
agent_json: Agent JSON dict (modified in place)
user_id: User ID to set
"""
for node in agent_json.get("nodes", []):
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
input_default = node.get("input_default") or {}
if not input_default.get("user_id"):
input_default["user_id"] = user_id
node["input_default"] = input_default
logger.debug(
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
)
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]:
@@ -672,22 +709,63 @@ async def save_agent_to_library(
Returns:
Tuple of (created Graph, LibraryAgent)
"""
# Populate user_id in AgentExecutorBlock nodes before conversion
_populate_agent_executor_user_ids(agent_json, user_id)
graph = json_to_graph(agent_json)
db = library_db()
if is_update:
return await db.update_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id)
if graph.id:
existing_versions = await get_graph_all_versions(graph.id, user_id)
if existing_versions:
latest_version = max(v.version for v in existing_versions)
graph.version = latest_version + 1
_reassign_node_ids(graph)
logger.info(f"Updating agent {graph.id} to version {graph.version}")
else:
graph.id = str(uuid.uuid4())
graph.version = 1
_reassign_node_ids(graph)
logger.info(f"Creating new agent with ID {graph.id}")
created_graph = await create_graph(graph, user_id)
library_agents = await library_db.create_library_agent(
graph=created_graph,
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
)
return created_graph, library_agents[0]
def graph_to_json(graph: Graph) -> dict[str, Any]:
"""Convert a Graph object to JSON format for the agent generator.
async def get_agent_as_json(
agent_id: str, user_id: str | None
) -> dict[str, Any] | None:
"""Fetch an agent and convert to JSON format for editing.
Args:
graph: Graph object to convert
agent_id: Graph ID or library agent ID
user_id: User ID
Returns:
Agent as JSON dict
Agent as JSON dict or None if not found
"""
graph = await get_graph(agent_id, version=None, user_id=user_id)
if not graph and user_id:
try:
library_agent = await library_db.get_library_agent(agent_id, user_id)
graph = await get_graph(
library_agent.graph_id, version=None, user_id=user_id
)
except NotFoundError:
pass
if not graph:
return None
nodes = []
for node in graph.nodes:
nodes.append(
@@ -724,43 +802,10 @@ def graph_to_json(graph: Graph) -> dict[str, Any]:
}
async def get_agent_as_json(
agent_id: str, user_id: str | None
) -> dict[str, Any] | None:
"""Fetch an agent and convert to JSON format for editing.
Args:
agent_id: Graph ID or library agent ID
user_id: User ID
Returns:
Agent as JSON dict or None if not found
"""
db = graph_db()
graph = await db.get_graph(agent_id, version=None, user_id=user_id)
if not graph and user_id:
try:
library_agent = await library_db().get_library_agent(agent_id, user_id)
graph = await db.get_graph(
library_agent.graph_id, version=None, user_id=user_id
)
except NotFoundError:
pass
if not graph:
return None
return graph_to_json(graph)
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[AgentSummary] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -773,12 +818,10 @@ async def generate_agent_patch(
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -786,43 +829,5 @@ async def generate_agent_patch(
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(
update_request,
current_agent,
_to_dict_list(library_agents),
operation_id,
task_id,
)
async def customize_template(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Customize a template/marketplace agent using natural language.
This is used when users want to modify a template or marketplace agent
to fit their specific needs before adding it to their library.
The external Agent Generator service handles:
- Understanding the modification request
- Applying changes to the template
- Fixing and validating the result
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for customize_template")
return await customize_template_external(
template_agent, modification_request, context
update_request, current_agent, _to_dict_list(library_agents)
)

View File

@@ -12,19 +12,8 @@ import httpx
from backend.util.settings import Settings
from .dummy import (
customize_template_dummy,
decompose_goal_dummy,
generate_agent_dummy,
generate_agent_patch_dummy,
get_blocks_dummy,
health_check_dummy,
)
logger = logging.getLogger(__name__)
_dummy_mode_warned = False
def _create_error_response(
error_message: str,
@@ -101,26 +90,10 @@ def _get_settings() -> Settings:
return _settings
def _is_dummy_mode() -> bool:
"""Check if dummy mode is enabled for testing."""
global _dummy_mode_warned
settings = _get_settings()
is_dummy = bool(settings.config.agentgenerator_use_dummy)
if is_dummy and not _dummy_mode_warned:
logger.warning(
"Agent Generator running in DUMMY MODE - returning mock responses. "
"Do not use in production!"
)
_dummy_mode_warned = True
return is_dummy
def is_external_service_configured() -> bool:
"""Check if external Agent Generator service is configured (or dummy mode)."""
"""Check if external Agent Generator service is configured."""
settings = _get_settings()
return bool(settings.config.agentgenerator_host) or bool(
settings.config.agentgenerator_use_dummy
)
return bool(settings.config.agentgenerator_host)
def _get_base_url() -> str:
@@ -164,15 +137,13 @@ async def decompose_goal_external(
- {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error
"""
if _is_dummy_mode():
return await decompose_goal_dummy(description, context, library_agents)
client = _get_client()
if context:
description = f"{description}\n\nAdditional context from user:\n{context}"
# Build the request payload
payload: dict[str, Any] = {"description": description}
if context:
# The external service uses user_instruction for additional context
payload["user_instruction"] = context
if library_agents:
payload["library_agents"] = library_agents
@@ -242,50 +213,24 @@ async def decompose_goal_external(
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
Agent JSON dict on success, or error dict {"type": "error", ...} on error
"""
if _is_dummy_mode():
return await generate_agent_dummy(
instructions, library_agents, operation_id, task_id
)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -317,8 +262,6 @@ async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
@@ -326,45 +269,21 @@ async def generate_agent_patch_external(
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
Updated agent JSON, clarifying questions dict, or error dict on error
"""
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents, operation_id, task_id
)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -408,91 +327,12 @@ async def generate_agent_patch_external(
return _create_error_response(error_msg, "unexpected_error")
async def customize_template_external(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Call the external service to customize a template/marketplace agent.
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error
"""
if _is_dummy_mode():
return await customize_template_dummy(
template_agent, modification_request, context
)
client = _get_client()
request = modification_request
if context:
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {
"template_agent_json": template_agent,
"modification_request": request,
}
try:
response = await client.post("/api/template-modification", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator template customization failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the customized agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.
Returns:
List of block info dicts or None on error
"""
if _is_dummy_mode():
return await get_blocks_dummy()
client = _get_client()
try:
@@ -526,9 +366,6 @@ async def health_check() -> bool:
if not is_external_service_configured():
return False
if _is_dummy_mode():
return await health_check_dummy()
client = _get_client()
try:

View File

@@ -5,15 +5,15 @@ import re
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
from backend.copilot.model import ChatSession
from backend.data.db_accessors import execution_db, library_db
from backend.data import execution as execution_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
from .models import (
AgentOutputResponse,
ErrorResponse,
@@ -34,7 +34,6 @@ class AgentOutputInput(BaseModel):
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
wait_if_running: int = Field(default=0, ge=0, le=300)
@field_validator(
"agent_name",
@@ -118,11 +117,6 @@ class AgentOutputTool(BaseTool):
Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
Wait for completion (optional):
- wait_if_running: Max seconds to wait if execution is still running (0-300).
If the execution is running/queued, waits up to this many seconds for completion.
Returns current status on timeout. If already finished, returns immediately.
"""
@property
@@ -152,13 +146,6 @@ class AgentOutputTool(BaseTool):
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
),
},
"wait_if_running": {
"type": "integer",
"description": (
"Max seconds to wait if execution is still running (0-300). "
"If running, waits for completion. Returns current state on timeout."
),
},
},
"required": [],
}
@@ -178,12 +165,10 @@ class AgentOutputTool(BaseTool):
Resolve agent from provided identifiers.
Returns (library_agent, error_message).
"""
lib_db = library_db()
# Priority 1: Exact library agent ID
if library_agent_id:
try:
agent = await lib_db.get_library_agent(library_agent_id, user_id)
agent = await library_db.get_library_agent(library_agent_id, user_id)
return agent, None
except Exception as e:
logger.warning(f"Failed to get library agent by ID: {e}")
@@ -197,7 +182,7 @@ class AgentOutputTool(BaseTool):
return None, f"Agent '{store_slug}' not found in marketplace"
# Find in user's library by graph_id
agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
if not agent:
return (
None,
@@ -209,7 +194,7 @@ class AgentOutputTool(BaseTool):
# Priority 3: Fuzzy name search in library
if agent_name:
try:
response = await lib_db.list_library_agents(
response = await library_db.list_library_agents(
user_id=user_id,
search_term=agent_name,
page_size=5,
@@ -238,20 +223,14 @@ class AgentOutputTool(BaseTool):
execution_id: str | None,
time_start: datetime | None,
time_end: datetime | None,
include_running: bool = False,
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
"""
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
Args:
include_running: If True, also look for running/queued executions (for waiting)
"""
exec_db = execution_db()
# If specific execution_id provided, fetch it directly
if execution_id:
execution = await exec_db.get_graph_execution(
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
@@ -260,22 +239,11 @@ class AgentOutputTool(BaseTool):
return None, [], f"Execution '{execution_id}' not found"
return execution, [], None
# Determine which statuses to query
statuses = [ExecutionStatus.COMPLETED]
if include_running:
statuses.extend(
[
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
]
)
# Get executions with time filters
executions = await exec_db.get_graph_executions(
# Get completed executions with time filters
executions = await execution_db.get_graph_executions(
graph_id=graph_id,
user_id=user_id,
statuses=statuses,
statuses=[ExecutionStatus.COMPLETED],
created_time_gte=time_start,
created_time_lte=time_end,
limit=10,
@@ -286,7 +254,7 @@ class AgentOutputTool(BaseTool):
# If only one execution, fetch full details
if len(executions) == 1:
full_execution = await exec_db.get_graph_execution(
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
@@ -294,7 +262,7 @@ class AgentOutputTool(BaseTool):
return full_execution, [], None
# Multiple executions - return latest with full details, plus list of available
full_execution = await exec_db.get_graph_execution(
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
@@ -342,28 +310,10 @@ class AgentOutputTool(BaseTool):
for e in available_executions[:5]
]
# Build appropriate message based on execution status
if execution.status == ExecutionStatus.COMPLETED:
message = f"Found execution outputs for agent '{agent.name}'"
elif execution.status == ExecutionStatus.FAILED:
message = f"Execution for agent '{agent.name}' failed"
elif execution.status == ExecutionStatus.TERMINATED:
message = f"Execution for agent '{agent.name}' was terminated"
elif execution.status in (
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
):
message = (
f"Execution for agent '{agent.name}' is still {execution.status.value}. "
"Results may be incomplete. Use wait_if_running to wait for completion."
)
else:
message = f"Found execution for agent '{agent.name}' (status: {execution.status.value})"
message = f"Found execution outputs for agent '{agent.name}'"
if len(available_executions) > 1:
message += (
f" Showing latest of {len(available_executions)} matching executions."
f". Showing latest of {len(available_executions)} matching executions."
)
return AgentOutputResponse(
@@ -430,7 +380,7 @@ class AgentOutputTool(BaseTool):
and not input_data.store_slug
):
# Fetch execution directly to get graph_id
execution = await execution_db().get_graph_execution(
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=input_data.execution_id,
include_node_executions=False,
@@ -442,7 +392,7 @@ class AgentOutputTool(BaseTool):
)
# Find library agent by graph_id
agent = await library_db().get_library_agent_by_graph_id(
agent = await library_db.get_library_agent_by_graph_id(
user_id, execution.graph_id
)
if not agent:
@@ -478,17 +428,13 @@ class AgentOutputTool(BaseTool):
# Parse time expression
time_start, time_end = parse_time_expression(input_data.run_time)
# Check if we should wait for running executions
wait_timeout = input_data.wait_if_running
# Fetch execution(s) - include running if we're going to wait
# Fetch execution(s)
execution, available_executions, exec_error = await self._get_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=input_data.execution_id or None,
time_start=time_start,
time_end=time_end,
include_running=wait_timeout > 0,
)
if exec_error:
@@ -497,17 +443,4 @@ class AgentOutputTool(BaseTool):
session_id=session_id,
)
# If we have an execution that's still running and we should wait
if execution and wait_timeout > 0 and execution.status not in TERMINAL_STATUSES:
logger.info(
f"Execution {execution.id} is {execution.status}, "
f"waiting up to {wait_timeout}s for completion"
)
execution = await wait_for_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_timeout,
)
return self._build_response(agent, execution, available_executions, session_id)

View File

@@ -4,7 +4,8 @@ import logging
import re
from typing import Literal
from backend.data.db_accessors import library_db, store_db
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.util.exceptions import DatabaseError, NotFoundError
from .models import (
@@ -44,10 +45,8 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
Returns:
AgentInfo if found, None otherwise
"""
lib_db = library_db()
try:
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
@@ -72,7 +71,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
)
try:
agent = await lib_db.get_library_agent(agent_id, user_id)
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
@@ -134,7 +133,7 @@ async def search_agents(
try:
if source == "marketplace":
logger.info(f"Searching marketplace for: {query}")
results = await store_db().get_store_agents(search_query=query, page_size=5)
results = await store_db.get_store_agents(search_query=query, page_size=5)
for agent in results.agents:
agents.append(
AgentInfo(
@@ -160,7 +159,7 @@ async def search_agents(
if not agents:
logger.info(f"Searching user library for: {query}")
results = await library_db().list_library_agents(
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
@@ -207,9 +206,9 @@ async def search_agents(
]
)
no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
if source == "marketplace"
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
else f"No agents matching '{query}' found in your library."
)
return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions
@@ -225,10 +224,10 @@ async def search_agents(
message = (
"Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
"Please ask the user if they would like to use any of these agents."
if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
)
return AgentsFoundResponse(

View File

@@ -5,8 +5,8 @@ from typing import Any
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.response_model import StreamToolOutputAvailable
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase

View File

@@ -3,7 +3,7 @@
import logging
from typing import Any
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
@@ -18,7 +18,6 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -99,10 +98,6 @@ class CreateAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
@@ -224,12 +219,7 @@ class CreateAgentTool(BaseTool):
logger.warning(f"Failed to enrich library agents from steps: {e}")
try:
agent_json = await generate_agent(
decomposition_result,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
agent_json = await generate_agent(decomposition_result, library_agents)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -273,19 +263,6 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))

View File

@@ -3,7 +3,7 @@
import logging
from typing import Any
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
@@ -17,7 +17,6 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -105,10 +104,6 @@ class EditAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -154,11 +149,7 @@ class EditAgentTool(BaseTool):
try:
result = await generate_agent_patch(
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
update_request, current_agent, library_agents
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -178,20 +169,6 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")

View File

@@ -2,7 +2,7 @@
from typing import Any
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
from .base import BaseTool

View File

@@ -3,43 +3,20 @@ from typing import Any
from prisma.enums import ContentType
from backend.blocks import get_block
from backend.blocks._base import BlockType
from backend.copilot.model import ChatSession
from backend.data.db_accessors import search
from .base import BaseTool, ToolResponseBase
from .models import (
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.api.features.chat.tools.models import (
BlockInfoSummary,
BlockInputFieldInfo,
BlockListResponse,
ErrorResponse,
NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import get_block
logger = logging.getLogger(__name__)
_TARGET_RESULTS = 10
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
_OVERFETCH_PAGE_SIZE = 40
# Block types that only work within graphs and cannot run standalone in CoPilot.
COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
COPILOT_EXCLUDED_BLOCK_IDS = {
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
"3b191d9f-356f-482d-8238-ba04b6d18381",
}
class FindBlockTool(BaseTool):
"""Tool for searching available blocks."""
@@ -55,8 +32,7 @@ class FindBlockTool(BaseTool):
"Blocks are reusable components that perform specific tasks like "
"sending emails, making API calls, processing text, etc. "
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
"The response includes each block's id, name, and description. "
"Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it."
"The response includes each block's id, required_inputs, and input_schema."
)
@property
@@ -108,11 +84,11 @@ class FindBlockTool(BaseTool):
try:
# Search for blocks using hybrid search
results, total = await search().unified_hybrid_search(
results, total = await unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=_OVERFETCH_PAGE_SIZE,
page_size=10,
)
if not results:
@@ -125,44 +101,67 @@ class FindBlockTool(BaseTool):
session_id=session_id,
)
# Enrich results with block information
# Enrich results with full block information
blocks: list[BlockInfoSummary] = []
for result in results:
block_id = result["content_id"]
block = get_block(block_id)
# Skip disabled blocks
if not block or block.disabled:
continue
if block and not block.disabled:
# Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception:
pass
try:
output_schema = block.output_schema.jsonschema()
except Exception:
pass
# Skip blocks excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
continue
# Get categories from block instance
categories = []
if hasattr(block, "categories") and block.categories:
categories = [cat.value for cat in block.categories]
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=[c.value for c in block.categories],
# Extract required inputs for easier use
required_inputs: list[BlockInputFieldInfo] = []
if input_schema:
properties = input_schema.get("properties", {})
required_fields = set(input_schema.get("required", []))
# Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
)
)
)
if len(blocks) >= _TARGET_RESULTS:
break
if blocks and len(blocks) < _TARGET_RESULTS:
logger.debug(
"find_block returned %d/%d results for query '%s' "
"(filtered %d excluded/disabled blocks)",
len(blocks),
_TARGET_RESULTS,
query,
len(results) - len(blocks),
)
if not blocks:
return NoResultsResponse(
@@ -176,7 +175,8 @@ class FindBlockTool(BaseTool):
return BlockListResponse(
message=(
f"Found {len(blocks)} block(s) matching '{query}'. "
"To see a block's inputs/outputs and execute it, use run_block with the block's 'id' - providing no inputs."
"To execute a block, use run_block with the block's 'id' field "
"and provide 'input_data' matching the block's input_schema."
),
blocks=blocks,
count=len(blocks),

View File

@@ -2,7 +2,7 @@
from typing import Any
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
from .base import BaseTool

View File

@@ -4,10 +4,13 @@ import logging
from pathlib import Path
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import DocPageResponse, ErrorResponse, ToolResponseBase
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
DocPageResponse,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)

View File

@@ -25,7 +25,6 @@ class ResponseType(str, Enum):
AGENT_SAVED = "agent_saved"
CLARIFICATION_NEEDED = "clarification_needed"
BLOCK_LIST = "block_list"
BLOCK_DETAILS = "block_details"
BLOCK_OUTPUT = "block_output"
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
@@ -39,17 +38,6 @@ class ResponseType(str, Enum):
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Web fetch
WEB_FETCH = "web_fetch"
# Code execution
BASH_EXEC = "bash_exec"
# Operation status check
OPERATION_STATUS = "operation_status"
# Feature request types
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
# Base response model
@@ -80,10 +68,6 @@ class AgentInfo(BaseModel):
has_external_trigger: bool | None = None
new_output: bool | None = None
graph_id: str | None = None
inputs: dict[str, Any] | None = Field(
default=None,
description="Input schema for the agent, including field names, types, and defaults",
)
class AgentsFoundResponse(ToolResponseBase):
@@ -210,20 +194,6 @@ class ErrorResponse(ToolResponseBase):
details: dict[str, Any] | None = None
class InputValidationErrorResponse(ToolResponseBase):
"""Response when run_agent receives unknown input fields."""
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
unrecognized_fields: list[str] = Field(
description="List of input field names that were not recognized"
)
inputs: dict[str, Any] = Field(
description="The agent's valid input schema for reference"
)
graph_id: str | None = None
graph_version: int | None = None
# Agent output models
class ExecutionOutputInfo(BaseModel):
"""Summary of a single execution's outputs."""
@@ -345,17 +315,11 @@ class BlockInfoSummary(BaseModel):
name: str
description: str
categories: list[str]
input_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block inputs",
)
output_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block outputs",
)
input_schema: dict[str, Any]
output_schema: dict[str, Any]
required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list,
description="List of input fields for this block",
description="List of required input fields for this block",
)
@@ -368,29 +332,10 @@ class BlockListResponse(ToolResponseBase):
query: str
usage_hint: str = Field(
default="To execute a block, call run_block with block_id set to the block's "
"'id' field and input_data containing the fields listed in required_inputs."
"'id' field and input_data containing the required fields from input_schema."
)
class BlockDetails(BaseModel):
"""Detailed block information."""
id: str
name: str
description: str
inputs: dict[str, Any] = {}
outputs: dict[str, Any] = {}
credentials: list[CredentialsMetaInput] = []
class BlockDetailsResponse(ToolResponseBase):
"""Response for block details (first run_block attempt)."""
type: ResponseType = ResponseType.BLOCK_DETAILS
block: BlockDetails
user_authenticated: bool = False
class BlockOutputResponse(ToolResponseBase):
"""Response for run_block tool."""
@@ -407,15 +352,11 @@ class OperationStartedResponse(ToolResponseBase):
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
@@ -439,72 +380,3 @@ class OperationInProgressResponse(ToolResponseBase):
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The Redis Streams completion
consumer will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""
type: ResponseType = ResponseType.WEB_FETCH
url: str
status_code: int
content_type: str
content: str
truncated: bool = False
class BashExecResponse(ToolResponseBase):
"""Response for bash_exec tool."""
type: ResponseType = ResponseType.BASH_EXEC
stdout: str
stderr: str
exit_code: int
timed_out: bool = False
# Feature request models
class FeatureRequestInfo(BaseModel):
"""Information about a feature request issue."""
id: str
identifier: str
title: str
description: str | None = None
class FeatureRequestSearchResponse(ToolResponseBase):
"""Response for search_feature_requests tool."""
type: ResponseType = ResponseType.FEATURE_REQUEST_SEARCH
results: list[FeatureRequestInfo]
count: int
query: str
class FeatureRequestCreatedResponse(ToolResponseBase):
"""Response for create_feature_request tool."""
type: ResponseType = ResponseType.FEATURE_REQUEST_CREATED
issue_id: str
issue_identifier: str
issue_title: str
issue_url: str
is_new_issue: bool # False if added to existing
customer_name: str

View File

@@ -5,13 +5,16 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from backend.copilot.config import ChatConfig
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
from backend.data.db_accessors import graph_db, library_db, user_db
from backend.data.execution import ExecutionStatus
from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import (
track_agent_run_success,
track_agent_scheduled,
)
from backend.api.features.library import db as library_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
from backend.executor import utils as execution_utils
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, NotFoundError
@@ -21,21 +24,17 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
AgentDetailsResponse,
AgentOutputResponse,
ErrorResponse,
ExecutionOptions,
ExecutionStartedResponse,
InputValidationErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
from .execution_utils import get_execution_outputs, wait_for_execution
from .utils import (
build_missing_credentials_from_graph,
extract_credentials_from_schema,
@@ -69,7 +68,6 @@ class RunAgentInput(BaseModel):
schedule_name: str = ""
cron: str = ""
timezone: str = "UTC"
wait_for_result: int = Field(default=0, ge=0, le=300)
@field_validator(
"username_agent_slug",
@@ -151,14 +149,6 @@ class RunAgentTool(BaseTool):
"type": "string",
"description": "IANA timezone for schedule (default: UTC)",
},
"wait_for_result": {
"type": "integer",
"description": (
"Max seconds to wait for execution to complete (0-300). "
"If >0, blocks until the execution finishes or times out. "
"Returns execution outputs when complete."
),
},
},
"required": [],
}
@@ -208,7 +198,7 @@ class RunAgentTool(BaseTool):
# Priority: library_agent_id if provided
if has_library_id:
library_agent = await library_db().get_library_agent(
library_agent = await library_db.get_library_agent(
params.library_agent_id, user_id
)
if not library_agent:
@@ -217,7 +207,9 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
# Get the graph from the library agent
graph = await graph_db().get_graph(
from backend.data.graph import get_graph
graph = await get_graph(
library_agent.graph_id,
library_agent.graph_version,
user_id=user_id,
@@ -268,7 +260,7 @@ class RunAgentTool(BaseTool):
),
requirements={
"credentials": requirements_creds_list,
"inputs": get_inputs_from_schema(graph.input_schema),
"inputs": self._get_inputs_list(graph.input_schema),
"execution_modes": self._get_execution_modes(graph),
},
),
@@ -281,22 +273,6 @@ class RunAgentTool(BaseTool):
input_properties = graph.input_schema.get("properties", {})
required_fields = set(graph.input_schema.get("required", []))
provided_inputs = set(params.inputs.keys())
valid_fields = set(input_properties.keys())
# Check for unknown input fields
unrecognized_fields = provided_inputs - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Agent was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=graph.input_schema,
graph_id=graph.id,
graph_version=graph.version,
)
# If agent has inputs but none were provided AND use_defaults is not set,
# always show what's available first so user can decide
@@ -353,7 +329,6 @@ class RunAgentTool(BaseTool):
graph=graph,
graph_credentials=graph_credentials,
inputs=params.inputs,
wait_for_result=params.wait_for_result,
)
except NotFoundError as e:
@@ -377,6 +352,22 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema."""
inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph."""
trigger_info = graph.trigger_setup_info
@@ -390,7 +381,7 @@ class RunAgentTool(BaseTool):
suffix: str,
) -> str:
"""Build a message describing available inputs for an agent."""
inputs_list = get_inputs_from_schema(graph.input_schema)
inputs_list = self._get_inputs_list(graph.input_schema)
required_names = [i["name"] for i in inputs_list if i["required"]]
optional_names = [i["name"] for i in inputs_list if not i["required"]]
@@ -437,9 +428,8 @@ class RunAgentTool(BaseTool):
graph: GraphModel,
graph_credentials: dict[str, CredentialsMetaInput],
inputs: dict[str, Any],
wait_for_result: int = 0,
) -> ToolResponseBase:
"""Execute an agent immediately, optionally waiting for completion."""
"""Execute an agent immediately."""
session_id = session.session_id
# Check rate limits
@@ -476,56 +466,6 @@ class RunAgentTool(BaseTool):
)
library_agent_link = f"/library/agents/{library_agent.id}"
# If wait_for_result is requested, wait for execution to complete
if wait_for_result > 0:
logger.info(
f"Waiting up to {wait_for_result}s for execution {execution.id}"
)
completed = await wait_for_execution(
user_id=user_id,
graph_id=library_agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_for_result,
)
if completed and completed.status == ExecutionStatus.COMPLETED:
outputs = get_execution_outputs(completed)
return AgentOutputResponse(
message=(
f"Agent '{library_agent.name}' completed successfully. "
f"View at {library_agent_link}."
),
session_id=session_id,
execution_id=execution.id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
outputs=outputs or {},
)
elif completed and completed.status == ExecutionStatus.FAILED:
return ErrorResponse(
message=(
f"Agent '{library_agent.name}' execution failed. "
f"View details at {library_agent_link}."
),
session_id=session_id,
)
else:
status = completed.status.value if completed else "unknown"
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is still {status} after "
f"{wait_for_result}s. Check results later at {library_agent_link}. "
f"Use view_agent_output with wait_if_running to check again."
),
session_id=session_id,
execution_id=execution.id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
)
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' execution started successfully. "
@@ -580,7 +520,7 @@ class RunAgentTool(BaseTool):
library_agent = await get_or_create_library_agent(graph, user_id)
# Get user timezone
user = await user_db().get_user_by_id(user_id)
user = await get_user_by_id(user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)
# Create schedule

View File

@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
# Should return error about missing schedule_name
assert result_data.get("type") == "error"
assert "schedule_name" in result_data["message"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
"""Test that run_agent returns input_validation_error for unknown input fields."""
user = setup_test_data["user"]
store_submission = setup_test_data["store_submission"]
tool = RunAgentTool()
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session(user_id=user.id)
# Execute with unknown input field names
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={
"unknown_field": "some value",
"another_unknown": "another value",
},
session=session,
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return input_validation_error type with unrecognized fields
assert result_data.get("type") == "input_validation_error"
assert "unrecognized_fields" in result_data
assert set(result_data["unrecognized_fields"]) == {
"another_unknown",
"unknown_field",
}
assert "inputs" in result_data # Contains the valid schema
assert "Agent was not executed" in result_data["message"]

View File

@@ -5,35 +5,24 @@ import uuid
from collections import defaultdict
from typing import Any
from pydantic_core import PydanticUndefined
from backend.blocks import get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
from .helpers import get_inputs_from_schema
from .models import (
BlockDetails,
BlockDetailsResponse,
BlockOutputResponse,
ErrorResponse,
InputValidationErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
from .utils import (
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__)
@@ -51,8 +40,8 @@ class RunBlockTool(BaseTool):
"Execute a specific block with the provided input data. "
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
"do NOT guess or make up block IDs. "
"On first attempt (without input_data), returns detailed schema showing "
"required inputs and outputs. Then call again with proper input_data to execute."
"Use the 'id' from find_block results and provide input_data "
"matching the block's required_inputs."
)
@property
@@ -67,19 +56,11 @@ class RunBlockTool(BaseTool):
"NEVER guess this - always get it from find_block first."
),
},
"block_name": {
"type": "string",
"description": (
"The block's human-readable name from find_block results. "
"Used for display purposes in the UI."
),
},
"input_data": {
"type": "object",
"description": (
"Input values for the block. "
"First call with empty {} to see the block's schema, "
"then call again with proper values to execute."
"Input values for the block. Use the 'required_inputs' field "
"from find_block to see what fields are needed."
),
},
},
@@ -90,6 +71,65 @@ class RunBlockTool(BaseTool):
def requires_auth(self) -> bool:
return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
# field_info.provider is a frozenset of acceptable providers
# field_info.supported_types is a frozenset of acceptable types
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in field_info.provider
and cred.type in field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute(
self,
user_id: str | None,
@@ -144,54 +184,14 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
# Check if block is excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
return ErrorResponse(
message=(
f"Block '{block.name}' cannot be run directly in CoPilot. "
"This block is designed for use within graphs only."
),
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
# Check credentials
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = (
await self._resolve_block_credentials(user_id, block, input_data)
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block
)
# Get block schemas for details/validation
try:
input_schema: dict[str, Any] = block.input_schema.jsonschema()
except Exception as e:
logger.warning(
"Failed to generate input schema for block %s: %s",
block_id,
e,
)
return ErrorResponse(
message=f"Block '{block.name}' has an invalid input schema",
error=str(e),
session_id=session_id,
)
try:
output_schema: dict[str, Any] = block.output_schema.jsonschema()
except Exception as e:
logger.warning(
"Failed to generate output schema for block %s: %s",
block_id,
e,
)
return ErrorResponse(
message=f"Block '{block.name}' has an invalid output schema",
error=str(e),
session_id=session_id,
)
if missing_credentials:
# Return setup requirements response with missing credentials
credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -224,56 +224,9 @@ class RunBlockTool(BaseTool):
graph_version=None,
)
# Check if this is a first attempt (required inputs missing)
# Return block details so user can see what inputs are needed
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
required_keys = set(input_schema.get("required", []))
required_non_credential_keys = required_keys - credentials_fields
provided_input_keys = set(input_data.keys()) - credentials_fields
# Check for unknown input fields
valid_fields = (
set(input_schema.get("properties", {}).keys()) - credentials_fields
)
unrecognized_fields = provided_input_keys - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Block was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=input_schema,
)
# Show details when not all required non-credential inputs are provided
if not (required_non_credential_keys <= provided_input_keys):
# Get credentials info for the response
credentials_meta = []
for field_name, cred_meta in matched_credentials.items():
credentials_meta.append(cred_meta)
return BlockDetailsResponse(
message=(
f"Block '{block.name}' details. "
"Provide input_data matching the inputs schema to execute the block."
),
session_id=session_id,
block=BlockDetails(
id=block_id,
name=block.name,
description=block.description or "",
inputs=input_schema,
outputs=output_schema,
credentials=credentials_meta,
),
user_authenticated=True,
)
try:
# Get or create user's workspace for CoPilot file operations
workspace = await workspace_db().get_or_create_workspace(user_id)
workspace = await get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
@@ -365,75 +318,29 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
async def _resolve_block_credentials(
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
def _resolve_discriminated_credentials(
self,
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
for field_name, field_schema in properties.items():
# Skip credential fields
if field_name in credentials_fields:
continue
resolved: dict[str, CredentialsFieldInfo] = {}
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved
return inputs_list

View File

@@ -5,17 +5,16 @@ from typing import Any
from prisma.enums import ContentType
from backend.copilot.model import ChatSession
from backend.data.db_accessors import search
from .base import BaseTool
from .models import (
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
DocSearchResult,
DocSearchResultsResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
logger = logging.getLogger(__name__)
@@ -118,7 +117,7 @@ class SearchDocsTool(BaseTool):
try:
# Search using hybrid search for DOCUMENTATION content type only
results, total = await search().unified_hybrid_search(
results, total = await unified_hybrid_search(
query=query,
content_types=[ContentType.DOCUMENTATION],
page=1,

View File

@@ -3,18 +3,13 @@
import logging
from typing import Any
from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
from backend.data.db_accessors import library_db, store_db
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
OAuth2Credentials,
)
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__)
@@ -38,15 +33,20 @@ async def fetch_graph_from_store_slug(
Raises:
DatabaseError: If there's a database error during lookup.
"""
sdb = store_db()
try:
store_agent = await sdb.get_store_agent_details(username, agent_name)
store_agent = await store_db.get_store_agent_details(username, agent_name)
except NotFoundError:
return None, None
# Get the graph from store listing version
graph = await sdb.get_available_graph(
store_agent.store_listing_version_id, hide_nodes=False
graph_meta = await store_db.get_available_graph(
store_agent.store_listing_version_id
)
graph = await graph_db.get_graph(
graph_id=graph_meta.id,
version=graph_meta.version,
user_id=None, # Public access
include_subgraphs=True,
)
return graph, store_agent
@@ -123,7 +123,7 @@ def build_missing_credentials_from_graph(
return {
field_key: _serialize_missing_credential(field_key, field_info)
for field_key, (field_info, _, _) in aggregated_fields.items()
for field_key, (field_info, _node_fields) in aggregated_fields.items()
if field_key not in matched_keys
}
@@ -210,13 +210,13 @@ async def get_or_create_library_agent(
Returns:
LibraryAgent instance
"""
existing = await library_db().get_library_agent_by_graph_id(
existing = await library_db.get_library_agent_by_graph_id(
graph_id=graph.id, user_id=user_id
)
if existing:
return existing
library_agents = await library_db().create_library_agent(
library_agents = await library_db.create_library_agent(
graph=graph,
user_id=user_id,
create_library_agents_for_sub_graphs=False,
@@ -225,99 +225,6 @@ async def get_or_create_library_agent(
return library_agents[0]
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, scopes, and host."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if cred.type == "oauth2" and not _credential_has_required_scopes(
cred, field_info
):
continue
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
async def match_user_credentials_to_graph(
user_id: str,
graph: GraphModel,
@@ -357,28 +264,16 @@ async def match_user_credentials_to_graph(
# provider is in the set of acceptable providers.
for credential_field_name, (
credential_requirements,
_,
_,
_node_fields,
) in aggregated_creds.items():
# Find first matching credential by provider, type, scopes, and host/URL
# Find first matching credential by provider, type, and scopes
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types
and (
cred.type != "oauth2"
or _credential_has_required_scopes(cred, credential_requirements)
)
and (
cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements)
)
and (
cred.provider != ProviderName.MCP
or _credential_is_for_mcp_server(cred, credential_requirements)
)
and _credential_has_required_scopes(cred, credential_requirements)
),
None,
)
@@ -423,46 +318,25 @@ async def match_user_credentials_to_graph(
def _credential_has_required_scopes(
credential: OAuth2Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if an OAuth2 credential has all the scopes required by the input."""
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
return set(credential.scopes).issuperset(requirements.required_scopes)
def _credential_is_for_host(
credential: HostScopedCredentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if a host-scoped credential matches the host required by the input."""
# We need to know the host to match host-scoped credentials to.
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
# to discriminator_values. No discriminator_values -> no host to match against.
if not requirements.discriminator_values:
return True
# Check that credential host matches required host.
# Host-scoped credential inputs are grouped by host, so any item from the set works.
return credential.matches_url(list(requirements.discriminator_values)[0])
def _credential_is_for_mcp_server(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if an MCP OAuth credential matches the required server URL."""
if not requirements.discriminator_values:
"""
Check if a credential has all the scopes required by the block.
For OAuth2 credentials, verifies that the credential's scopes are a superset
of the required scopes. For other credential types, returns True (no scope check).
"""
# Only OAuth2 credentials have scopes to check
if credential.type != "oauth2":
return True
server_url = (
credential.metadata.get("mcp_server_url")
if isinstance(credential, OAuth2Credentials)
else None
)
return server_url in requirements.discriminator_values if server_url else False
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)
async def check_user_has_required_credentials(

View File

@@ -6,8 +6,8 @@ from typing import Any, Optional
from pydantic import BaseModel
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
@@ -88,9 +88,7 @@ class ListWorkspaceFilesTool(BaseTool):
@property
def description(self) -> str:
return (
"List files in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read/Glob tools instead. "
"List files in the user's workspace. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@@ -148,7 +146,7 @@ class ListWorkspaceFilesTool(BaseTool):
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await workspace_db().get_or_create_workspace(user_id)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
@@ -167,8 +165,8 @@ class ListWorkspaceFilesTool(BaseTool):
file_id=f.id,
name=f.name,
path=f.path,
mime_type=f.mime_type,
size_bytes=f.size_bytes,
mime_type=f.mimeType,
size_bytes=f.sizeBytes,
)
for f in files
]
@@ -206,9 +204,7 @@ class ReadWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Read a file from the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read tool instead. "
"Read a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
@@ -284,7 +280,7 @@ class ReadWorkspaceFileTool(BaseTool):
)
try:
workspace = await workspace_db().get_or_create_workspace(user_id)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
@@ -309,8 +305,8 @@ class ReadWorkspaceFileTool(BaseTool):
target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL
is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mime_type)
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
@@ -321,7 +317,7 @@ class ReadWorkspaceFileTool(BaseTool):
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mime_type,
mime_type=file_info.mimeType,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
@@ -350,11 +346,11 @@ class ReadWorkspaceFileTool(BaseTool):
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mime_type,
size_bytes=file_info.size_bytes,
mime_type=file_info.mimeType,
size_bytes=file_info.sizeBytes,
download_url=download_url,
preview=preview,
message=f"File: {file_info.name} ({file_info.size_bytes} bytes). Use download_url to retrieve content.",
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)
@@ -382,9 +378,7 @@ class WriteWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Write or create a file in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Write tool instead. "
"Write or create a file in the user's workspace. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
@@ -484,7 +478,7 @@ class WriteWorkspaceFileTool(BaseTool):
# Virus scan
await scan_content_safe(content, filename=filename)
workspace = await workspace_db().get_or_create_workspace(user_id)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
@@ -500,7 +494,7 @@ class WriteWorkspaceFileTool(BaseTool):
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
size_bytes=file_record.size_bytes,
size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)
@@ -529,7 +523,7 @@ class DeleteWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Delete a file from the user's persistent workspace (cloud storage). "
"Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
@@ -583,7 +577,7 @@ class DeleteWorkspaceFileTool(BaseTool):
)
try:
workspace = await workspace_db().get_or_create_workspace(user_id)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)

View File

@@ -1,7 +1,7 @@
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
from typing import TYPE_CHECKING, Annotated, List, Literal
from autogpt_libs.auth import get_user_id
from fastapi import (
@@ -14,7 +14,7 @@ from fastapi import (
Security,
status,
)
from pydantic import BaseModel, Field, SecretStr, model_validator
from pydantic import BaseModel, Field, SecretStr
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -39,11 +39,7 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -106,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
scopes: list[str] | None
username: str | None
host: str | None = Field(
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
default=None, description="Host pattern for host-scoped credentials"
)
@model_validator(mode="before")
@classmethod
def _normalize_provider(cls, data: Any) -> Any:
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
if isinstance(data, dict):
prov = data.get("provider", "")
if isinstance(prov, str) and prov.startswith("ProviderName."):
member = prov.removeprefix("ProviderName.")
try:
data = {**data, "provider": ProviderName[member].value}
except KeyError:
pass
return data
@staticmethod
def get_host(cred: Credentials) -> str | None:
"""Extract host from credential: HostScoped host or MCP server URL."""
if isinstance(cred, HostScopedCredentials):
return cred.host
if isinstance(cred, OAuth2Credentials) and cred.provider in (
ProviderName.MCP,
ProviderName.MCP.value,
"ProviderName.MCP",
):
return (cred.metadata or {}).get("mcp_server_url")
return None
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback(
@@ -211,7 +179,9 @@ async def callback(
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=(CredentialsMetaResponse.get_host(credentials)),
host=(
credentials.host if isinstance(credentials, HostScopedCredentials) else None
),
)
@@ -229,7 +199,7 @@ async def list_credentials(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -252,7 +222,7 @@ async def list_credentials_by_provider(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -352,11 +322,7 @@ async def delete_credentials(
tokens_revoked = None
if isinstance(creds, OAuth2Credentials):
if provider_matches(provider.value, ProviderName.MCP.value):
# MCP uses dynamic per-server OAuth — create handler from metadata
handler = create_mcp_oauth_handler(creds)
else:
handler = _get_provider_oauth_handler(request, provider)
handler = _get_provider_oauth_handler(request, provider)
tokens_revoked = await handler.revoke_tokens(creds)
return CredentialsDeletionResponse(revoked=tokens_revoked)

View File

@@ -12,16 +12,14 @@ import backend.api.features.store.image_gen as store_image_gen
import backend.api.features.store.media as store_media
import backend.data.graph as graph_db
import backend.data.integrations as integrations_db
from backend.data.block import BlockInput
from backend.data.db import transaction
from backend.data.execution import get_graph_execution
from backend.data.graph import GraphSettings
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
from backend.data.model import CredentialsMetaInput, GraphInput
from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
on_graph_deactivate,
)
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
from backend.util.json import SafeJson
@@ -373,7 +371,7 @@ async def get_library_agent_by_graph_id(
async def add_generated_agent_image(
graph: graph_db.GraphBaseMeta,
graph: graph_db.BaseGraph,
user_id: str,
library_agent_id: str,
) -> Optional[prisma.models.LibraryAgent]:
@@ -539,92 +537,6 @@ async def update_agent_version_in_library(
return library_model.LibraryAgent.from_db(lib)
async def create_graph_in_library(
graph: graph_db.Graph,
user_id: str,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new graph and add it to the user's library."""
graph.version = 1
graph_model = graph_db.make_graph_model(graph, user_id)
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
created_graph = await graph_db.create_graph(graph_model, user_id)
library_agents = await create_library_agent(
graph=created_graph,
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
)
if created_graph.is_active:
created_graph = await on_graph_activate(created_graph, user_id=user_id)
return created_graph, library_agents[0]
async def update_graph_in_library(
graph: graph_db.Graph,
user_id: str,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new version of an existing graph and update the library entry."""
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
current_active_version = (
next((v for v in existing_versions if v.is_active), None)
if existing_versions
else None
)
graph.version = (
max(v.version for v in existing_versions) + 1 if existing_versions else 1
)
graph_model = graph_db.make_graph_model(graph, user_id)
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
created_graph = await graph_db.create_graph(graph_model, user_id)
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
if not library_agent:
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
library_agent = await update_library_agent_version_and_settings(
user_id, created_graph
)
if created_graph.is_active:
created_graph = await on_graph_activate(created_graph, user_id=user_id)
await graph_db.set_graph_active_version(
graph_id=created_graph.id,
version=created_graph.version,
user_id=user_id,
)
if current_active_version:
await on_graph_deactivate(current_active_version, user_id=user_id)
return created_graph, library_agent
async def update_library_agent_version_and_settings(
user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
"""Update library agent to point to new graph version and sync settings."""
library = await update_agent_version_in_library(
user_id, agent_graph.id, agent_graph.version
)
updated_settings = GraphSettings.from_graph(
graph=agent_graph,
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
)
if updated_settings != library.settings:
library = await update_library_agent(
library_agent_id=library.id,
user_id=user_id,
settings=updated_settings,
)
return library
async def update_library_agent(
library_agent_id: str,
user_id: str,
@@ -1129,7 +1041,7 @@ async def create_preset_from_graph_execution(
async def update_preset(
user_id: str,
preset_id: str,
inputs: Optional[GraphInput] = None,
inputs: Optional[BlockInput] = None,
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
name: Optional[str] = None,
description: Optional[str] = None,

View File

@@ -6,12 +6,9 @@ import prisma.enums
import prisma.models
import pydantic
from backend.data.block import BlockInput
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
from backend.data.model import (
CredentialsMetaInput,
GraphInput,
is_credentials_field_name,
)
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.util.json import loads as json_loads
from backend.util.models import Pagination
@@ -326,7 +323,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
graph_id: str
graph_version: int
inputs: GraphInput
inputs: BlockInput
credentials: dict[str, CredentialsMetaInput]
name: str
@@ -355,7 +352,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
Request model used when updating a preset for a library agent.
"""
inputs: Optional[GraphInput] = None
inputs: Optional[BlockInput] = None
credentials: Optional[dict[str, CredentialsMetaInput]] = None
name: Optional[str] = None
@@ -398,7 +395,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
"Webhook must be included in AgentPreset query when webhookId is set"
)
input_data: GraphInput = {}
input_data: BlockInput = {}
input_credentials: dict[str, CredentialsMetaInput] = {}
for preset_input in preset.InputPresets:

View File

@@ -1,404 +0,0 @@
"""
MCP (Model Context Protocol) API routes.
Provides endpoints for MCP tool discovery and OAuth authentication so the
frontend can list available tools on an MCP server before placing a block.
"""
import logging
from typing import Annotated, Any
from urllib.parse import urlparse
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
router = fastapi.APIRouter(tags=["mcp"])
creds_manager = IntegrationCredentialsManager()
# ====================== Tool Discovery ====================== #
class DiscoverToolsRequest(BaseModel):
"""Request to discover tools on an MCP server."""
server_url: str = Field(description="URL of the MCP server")
auth_token: str | None = Field(
default=None,
description="Optional Bearer token for authenticated MCP servers",
)
class MCPToolResponse(BaseModel):
"""A single MCP tool returned by discovery."""
name: str
description: str
input_schema: dict[str, Any]
class DiscoverToolsResponse(BaseModel):
"""Response containing the list of tools available on an MCP server."""
tools: list[MCPToolResponse]
server_name: str | None = None
protocol_version: str | None = None
@router.post(
"/discover-tools",
summary="Discover available tools on an MCP server",
response_model=DiscoverToolsResponse,
)
async def discover_tools(
request: DiscoverToolsRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> DiscoverToolsResponse:
"""
Connect to an MCP server and return its available tools.
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
):
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
# Refresh the token if expired before using it
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
client = MCPClient(request.server_url, auth_token=auth_token)
try:
init_result = await client.initialize()
tools = await client.list_tools()
except HTTPClientError as e:
if e.status_code in (401, 403):
raise fastapi.HTTPException(
status_code=401,
detail="This MCP server requires authentication. "
"Please provide a valid auth token.",
)
raise fastapi.HTTPException(status_code=502, detail=str(e))
except MCPClientError as e:
raise fastapi.HTTPException(status_code=502, detail=str(e))
except Exception as e:
raise fastapi.HTTPException(
status_code=502,
detail=f"Failed to connect to MCP server: {e}",
)
return DiscoverToolsResponse(
tools=[
MCPToolResponse(
name=t.name,
description=t.description,
input_schema=t.input_schema,
)
for t in tools
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or urlparse(request.server_url).hostname
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
)
# ======================== OAuth Flow ======================== #
class MCPOAuthLoginRequest(BaseModel):
"""Request to start an OAuth flow for an MCP server."""
server_url: str = Field(description="URL of the MCP server that requires OAuth")
class MCPOAuthLoginResponse(BaseModel):
"""Response with the OAuth login URL for the user to authenticate."""
login_url: str
state_token: str
@router.post(
"/oauth/login",
summary="Initiate OAuth login for an MCP server",
)
async def mcp_oauth_login(
request: MCPOAuthLoginRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> MCPOAuthLoginResponse:
"""
Discover OAuth metadata from the MCP server and return a login URL.
1. Discovers the protected-resource metadata (RFC 9728)
2. Fetches the authorization server metadata (RFC 8414)
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
client = MCPClient(request.server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
metadata: dict[str, Any] | None = None
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", request.server_url)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
else:
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
# and serve OAuth metadata directly without protected-resource metadata.
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(request.server_url)
if (
not metadata
or "authorization_endpoint" not in metadata
or "token_endpoint" not in metadata
):
raise fastapi.HTTPException(
status_code=400,
detail="This MCP server does not advertise OAuth support. "
"You may need to provide an auth token manually.",
)
authorize_url = metadata["authorization_endpoint"]
token_url = metadata["token_endpoint"]
registration_endpoint = metadata.get("registration_endpoint")
revoke_url = metadata.get("revocation_endpoint")
# Step 3: Dynamic Client Registration (RFC 7591) if available
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
client_id = ""
client_secret = ""
if registration_endpoint:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, request.server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
# Step 4: Store state token with OAuth metadata for the callback
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
"scopes_supported", []
)
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id,
ProviderName.MCP.value,
scopes,
state_metadata={
"authorize_url": authorize_url,
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": request.server_url,
"client_id": client_id,
"client_secret": client_secret,
},
)
# Step 5: Build and return the login URL
handler = MCPOAuthHandler(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
authorize_url=authorize_url,
token_url=token_url,
resource_url=resource_url,
)
login_url = handler.get_login_url(
scopes, state_token, code_challenge=code_challenge
)
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
class MCPOAuthCallbackRequest(BaseModel):
"""Request to exchange an OAuth code for tokens."""
code: str = Field(description="Authorization code from OAuth callback")
state_token: str = Field(description="State token for CSRF verification")
class MCPOAuthCallbackResponse(BaseModel):
"""Response after successfully storing OAuth credentials."""
credential_id: str
@router.post(
"/oauth/callback",
summary="Exchange OAuth code for MCP tokens",
)
async def mcp_oauth_callback(
request: MCPOAuthCallbackRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Exchange the authorization code for tokens and store the credential.
The frontend calls this after receiving the OAuth code from the popup.
On success, subsequent ``/discover-tools`` calls for the same server URL
will automatically use the stored credential.
"""
valid_state = await creds_manager.store.verify_state_token(
user_id, request.state_token, ProviderName.MCP.value
)
if not valid_state:
raise fastapi.HTTPException(
status_code=400,
detail="Invalid or expired state token.",
)
meta = valid_state.state_metadata
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
handler = MCPOAuthHandler(
client_id=meta["client_id"],
client_secret=meta.get("client_secret", ""),
redirect_uri=redirect_uri,
authorize_url=meta["authorize_url"],
token_url=meta["token_url"],
revoke_url=meta.get("revoke_url"),
resource_url=meta.get("resource_url"),
)
try:
credentials = await handler.exchange_code_for_tokens(
request.code, valid_state.scopes, valid_state.code_verifier
)
except Exception as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"OAuth token exchange failed: {e}",
)
# Enrich credential metadata for future lookup and token refresh
if credentials.metadata is None:
credentials.metadata = {}
credentials.metadata["mcp_server_url"] = meta["server_url"]
credentials.metadata["mcp_client_id"] = meta["client_id"]
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
for old in old_creds:
if (
isinstance(old, OAuth2Credentials)
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
f"Removed old MCP credential {old.id} for {meta['server_url']}"
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
await creds_manager.create(user_id, credentials)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=credentials.metadata.get("mcp_server_url"),
)
# ======================== Helpers ======================== #
async def _register_mcp_client(
registration_endpoint: str,
redirect_uri: str,
server_url: str,
) -> dict[str, Any] | None:
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
try:
response = await Requests(raise_for_status=True).post(
registration_endpoint,
json={
"client_name": "AutoGPT Platform",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
},
)
data = response.json()
if isinstance(data, dict) and "client_id" in data:
return data
return None
except Exception as e:
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
return None

View File

@@ -1,436 +0,0 @@
"""Tests for MCP API routes.
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
"""
from unittest.mock import AsyncMock, patch
import fastapi
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.auth import get_user_id
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
app.include_router(router)
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
@pytest_asyncio.fixture(scope="module")
async def client():
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
yield c
class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_success(self, client):
mock_tools = [
MCPTool(
name="get_weather",
description="Get weather for a city",
input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
),
MCPTool(
name="add_numbers",
description="Add two numbers",
input_schema={
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"},
},
},
),
]
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
"protocolVersion": "2025-03-26",
"serverInfo": {"name": "test-server"},
}
)
instance.list_tools = AsyncMock(return_value=mock_tools)
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert len(data["tools"]) == 2
assert data["tools"][0]["name"] == "get_weather"
assert data["tools"][1]["name"] == "add_numbers"
assert data["server_name"] == "test-server"
assert data["protocol_version"] == "2025-03-26"
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_with_auth_token(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={
"server_url": "https://mcp.example.com/mcp",
"auth_token": "my-secret-token",
},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="my-secret-token",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auto_uses_stored_credential(self, client):
"""When no explicit token is given, stored MCP credentials are used."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
access_token=SecretStr("stored-token-123"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="stored-token-123",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_mcp_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://bad-server.example.com/mcp"},
)
assert response.status_code == 502
assert "Connection refused" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_generic_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
response = await client.post(
"/discover-tools",
json={"server_url": "https://timeout.example.com/mcp"},
)
assert response.status_code == 502
assert "Failed to connect" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auth_required(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_forbidden(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_missing_url(self, client):
response = await client.post("/discover-tools", json={})
assert response.status_code == 422
class TestOAuthLogin:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_success(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch(
"backend.api.features.mcp.routes._register_mcp_client"
) as mock_register,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.sentry.io"],
"resource": "https://mcp.sentry.dev/mcp",
"scopes_supported": ["openid"],
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.sentry.io/authorize",
"token_endpoint": "https://auth.sentry.io/token",
"registration_endpoint": "https://auth.sentry.io/register",
}
)
mock_register.return_value = {
"client_id": "registered-client-id",
"client_secret": "registered-secret",
}
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-token-123", "code-challenge-abc")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.sentry.dev/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "login_url" in data
assert data["state_token"] == "state-token-123"
assert "auth.sentry.io/authorize" in data["login_url"]
assert "registered-client-id" in data["login_url"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_no_oauth_support(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.discover_auth = AsyncMock(return_value=None)
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
response = await client.post(
"/oauth/login",
json={"server_url": "https://simple-server.example.com/mcp"},
)
assert response.status_code == 400
assert "does not advertise OAuth" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_fallback_to_public_client(self, client):
"""When DCR is unavailable, falls back to default public client ID."""
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
# No registration_endpoint
}
)
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-abc", "challenge-xyz")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "autogpt-platform" in data["login_url"]
class TestOAuthCallback:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_success(self, client):
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
access_token=SecretStr("access-token-xyz"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={
"mcp_token_url": "https://auth.sentry.io/token",
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
},
)
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
# Mock state verification
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.sentry.io/authorize",
"token_url": "https://auth.sentry.io/token",
"client_id": "test-client-id",
"client_secret": "test-secret",
"server_url": "https://mcp.sentry.dev/mcp",
}
mock_state.scopes = ["openid"]
mock_state.code_verifier = "verifier-123"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
mock_cm.create = AsyncMock()
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
return_value=mock_creds
)
# Mock old credential cleanup
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
response = await client.post(
"/oauth/callback",
json={"code": "auth-code-abc", "state_token": "state-token-123"},
)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_invalid_state(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
response = await client.post(
"/oauth/callback",
json={"code": "auth-code", "state_token": "bad-state"},
)
assert response.status_code == 400
assert "Invalid or expired" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_token_exchange_fails(self, client):
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
"client_id": "cid",
"server_url": "https://mcp.example.com/mcp",
}
mock_state.scopes = []
mock_state.code_verifier = "v"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
side_effect=RuntimeError("Token exchange failed")
)
response = await client.post(
"/oauth/callback",
json={"code": "bad-code", "state_token": "state"},
)
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()

View File

@@ -5,8 +5,8 @@ from typing import Optional
import aiohttp
from fastapi import HTTPException
from backend.blocks import get_block
from backend.data import graph as graph_db
from backend.data.block import get_block
from backend.util.settings import Settings
from .models import ApiResponse, ChatRequest, GraphData

View File

@@ -152,7 +152,7 @@ class BlockHandler(ContentHandler):
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch blocks without embeddings."""
from backend.blocks import get_blocks
from backend.data.block import get_blocks
# Get all available blocks
all_blocks = get_blocks()
@@ -249,7 +249,7 @@ class BlockHandler(ContentHandler):
async def get_stats(self) -> dict[str, int]:
"""Get statistics about block embedding coverage."""
from backend.blocks import get_blocks
from backend.data.block import get_blocks
all_blocks = get_blocks()

View File

@@ -93,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
mock_existing = []
with patch(
"backend.blocks.get_blocks",
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
@@ -135,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
mock_embedded = [{"count": 2}]
with patch(
"backend.blocks.get_blocks",
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
@@ -327,7 +327,7 @@ async def test_block_handler_handles_missing_attributes():
mock_blocks = {"block-minimal": mock_block_class}
with patch(
"backend.blocks.get_blocks",
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
@@ -360,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
with patch(
"backend.blocks.get_blocks",
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(

View File

@@ -1,7 +1,7 @@
import asyncio
import logging
from datetime import datetime, timezone
from typing import Any, Literal, overload
from typing import Any, Literal
import fastapi
import prisma.enums
@@ -11,8 +11,8 @@ import prisma.types
from backend.data.db import transaction
from backend.data.graph import (
GraphMeta,
GraphModel,
GraphModelWithoutNodes,
get_graph,
get_graph_as_admin,
get_sub_graphs,
@@ -334,22 +334,7 @@ async def get_store_agent_details(
raise DatabaseError("Failed to fetch agent details") from e
@overload
async def get_available_graph(
store_listing_version_id: str, hide_nodes: Literal[False]
) -> GraphModel: ...
@overload
async def get_available_graph(
store_listing_version_id: str, hide_nodes: Literal[True] = True
) -> GraphModelWithoutNodes: ...
async def get_available_graph(
store_listing_version_id: str,
hide_nodes: bool = True,
) -> GraphModelWithoutNodes | GraphModel:
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
try:
# Get avaialble, non-deleted store listing version
store_listing_version = (
@@ -359,7 +344,7 @@ async def get_available_graph(
"isAvailable": True,
"isDeleted": False,
},
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
include={"AgentGraph": {"include": {"Nodes": True}}},
)
)
@@ -369,9 +354,7 @@ async def get_available_graph(
detail=f"Store listing version {store_listing_version_id} not found",
)
return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
store_listing_version.AgentGraph
)
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
except Exception as e:
logger.error(f"Error getting agent: {e}")

View File

@@ -662,7 +662,7 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
)
current_ids = {row["id"] for row in valid_agents}
elif content_type == ContentType.BLOCK:
from backend.blocks import get_blocks
from backend.data.block import get_blocks
current_ids = set(get_blocks().keys())
elif content_type == ContentType.DOCUMENTATION:

View File

@@ -454,9 +454,6 @@ async def test_unified_hybrid_search_pagination(
cleanup_embeddings: list,
):
"""Test unified search pagination works correctly."""
# Use a unique search term to avoid matching other test data
unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
# Create multiple items
content_ids = []
for i in range(5):
@@ -468,14 +465,14 @@ async def test_unified_hybrid_search_pagination(
content_type=ContentType.BLOCK,
content_id=content_id,
embedding=mock_embedding,
searchable_text=f"{unique_term} item number {i}",
searchable_text=f"pagination test item number {i}",
metadata={"index": i},
user_id=None,
)
# Get first page
page1_results, total1 = await unified_hybrid_search(
query=unique_term,
query="pagination test",
content_types=[ContentType.BLOCK],
page=1,
page_size=2,
@@ -483,7 +480,7 @@ async def test_unified_hybrid_search_pagination(
# Get second page
page2_results, total2 = await unified_hybrid_search(
query=unique_term,
query="pagination test",
content_types=[ContentType.BLOCK],
page=2,
page_size=2,

View File

@@ -8,7 +8,6 @@ Includes BM25 reranking for improved lexical relevance.
import logging
import re
import time
from dataclasses import dataclass
from typing import Any, Literal
@@ -363,11 +362,7 @@ async def unified_hybrid_search(
LIMIT {limit_param} OFFSET {offset_param}
"""
try:
results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
results = await query_raw_with_schema(sql_query, *params)
total = results[0]["total_count"] if results else 0
# Apply BM25 reranking
@@ -691,11 +686,7 @@ async def hybrid_search(
LIMIT {limit_param} OFFSET {offset_param}
"""
try:
results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
results = await query_raw_with_schema(sql_query, *params)
total = results[0]["total_count"] if results else 0
@@ -727,87 +718,6 @@ async def hybrid_search_simple(
return await hybrid_search(query=query, page=page, page_size=page_size)
# ============================================================================
# Diagnostics
# ============================================================================
# Rate limit: only log vector error diagnostics once per this interval
_VECTOR_DIAG_INTERVAL_SECONDS = 60
_last_vector_diag_time: float = 0
async def _log_vector_error_diagnostics(error: Exception) -> None:
"""Log diagnostic info when 'type vector does not exist' error occurs.
Note: Diagnostic queries use query_raw_with_schema which may run on a different
pooled connection than the one that failed. Session-level search_path can differ,
so these diagnostics show cluster-wide state, not necessarily the failed session.
Includes rate limiting to avoid log spam - only logs once per minute.
Caller should re-raise the error after calling this function.
"""
global _last_vector_diag_time
# Check if this is the vector type error
error_str = str(error).lower()
if not (
"type" in error_str and "vector" in error_str and "does not exist" in error_str
):
return
# Rate limit: only log once per interval
now = time.time()
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
return
_last_vector_diag_time = now
try:
diagnostics: dict[str, object] = {}
try:
search_path_result = await query_raw_with_schema("SHOW search_path")
diagnostics["search_path"] = search_path_result
except Exception as e:
diagnostics["search_path"] = f"Error: {e}"
try:
schema_result = await query_raw_with_schema("SELECT current_schema()")
diagnostics["current_schema"] = schema_result
except Exception as e:
diagnostics["current_schema"] = f"Error: {e}"
try:
user_result = await query_raw_with_schema(
"SELECT current_user, session_user, current_database()"
)
diagnostics["user_info"] = user_result
except Exception as e:
diagnostics["user_info"] = f"Error: {e}"
try:
# Check pgvector extension installation (cluster-wide, stable info)
ext_result = await query_raw_with_schema(
"SELECT extname, extversion, nspname as schema "
"FROM pg_extension e "
"JOIN pg_namespace n ON e.extnamespace = n.oid "
"WHERE extname = 'vector'"
)
diagnostics["pgvector_extension"] = ext_result
except Exception as e:
diagnostics["pgvector_extension"] = f"Error: {e}"
logger.error(
f"Vector type error diagnostics:\n"
f" Error: {error}\n"
f" search_path: {diagnostics.get('search_path')}\n"
f" current_schema: {diagnostics.get('current_schema')}\n"
f" user_info: {diagnostics.get('user_info')}\n"
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
)
except Exception as diag_error:
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
# for existing code that expects the popularity parameter
HybridSearchWeights = StoreAgentSearchWeights

View File

@@ -7,7 +7,16 @@ from replicate.client import Client as ReplicateClient
from replicate.exceptions import ReplicateError
from replicate.helpers import FileOutput
from backend.data.graph import GraphBaseMeta
from backend.blocks.ideogram import (
AspectRatio,
ColorPalettePreset,
IdeogramModelBlock,
IdeogramModelName,
MagicPromptOption,
StyleType,
UpscaleOption,
)
from backend.data.graph import BaseGraph
from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials
from backend.util.request import Requests
@@ -25,14 +34,14 @@ class ImageStyle(str, Enum):
DIGITAL_ART = "digital art"
async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
if settings.config.use_agent_image_generation_v2:
return await generate_agent_image_v2(graph=agent)
else:
return await generate_agent_image_v1(agent=agent)
async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Ideogram model.
Returns:
@@ -41,31 +50,18 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
if not ideogram_credentials.api_key:
raise ValueError("Missing Ideogram API key")
from backend.blocks.ideogram import (
AspectRatio,
ColorPalettePreset,
IdeogramModelBlock,
IdeogramModelName,
MagicPromptOption,
StyleType,
UpscaleOption,
)
name = graph.name
description = f"{name} ({graph.description})" if graph.description else name
prompt = (
"Create a visually striking retro-futuristic vector pop art illustration "
f'prominently featuring "{name}" in bold typography. The image clearly and '
f"literally depicts a {description}, along with recognizable objects directly "
f"associated with the primary function of a {name}. "
f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
f"clearly conveying the purpose of a {name}. "
"Maintain vibrant, limited-palette colors, sharp vector lines, "
"geometric shapes, flat illustration techniques, and solid colors "
"without gradients or shading. Preserve a retro-futuristic aesthetic "
"influenced by mid-century futurism and 1960s psychedelia, "
"prioritizing clear visual storytelling and thematic clarity above all else."
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
f"along with recognizable objects directly associated with the primary function of a {name}. "
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
f"prioritizing clear visual storytelling and thematic clarity above all else."
)
custom_colors = [
@@ -103,12 +99,12 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
return io.BytesIO(response.content)
async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Flux model via Replicate API.
Args:
agent (GraphBaseMeta | AgentGraph): The agent to generate an image for
agent (Graph): The agent to generate an image for
Returns:
io.BytesIO: The generated image as bytes
@@ -118,13 +114,7 @@ async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.Bytes
raise ValueError("Missing Replicate API key in settings")
# Construct prompt from agent details
prompt = (
"Create a visually engaging app store thumbnail for the AI agent "
"that highlights what it does in a clear and captivating way:\n"
f"- **Name**: {agent.name}\n"
f"- **Description**: {agent.description}\n"
f"Focus on showcasing its core functionality with an appealing design."
)
prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
# Set up Replicate client
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)

View File

@@ -278,7 +278,7 @@ async def get_agent(
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
) -> backend.data.graph.GraphMeta:
"""
Get Agent Graph from Store Listing Version ID.
"""

View File

@@ -40,11 +40,10 @@ from backend.api.model import (
UpdateTimezoneRequest,
UploadFileResponse,
)
from backend.blocks import get_block, get_blocks
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.auth import api_key as api_key_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.credit import (
AutoTopUpConfig,
RefundRequest,
@@ -102,6 +101,7 @@ from backend.util.timezone_utils import (
from backend.util.virus_scanner import scan_content_safe
from .library import db as library_db
from .library import model as library_model
from .store.model import StoreAgentDetails
@@ -823,16 +823,18 @@ async def update_graph(
graph: graph_db.Graph,
user_id: Annotated[str, Security(get_user_id)],
) -> graph_db.GraphModel:
# Sanity check
if graph.id and graph.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI")
# Determine new version
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not existing_versions:
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
latest_version_number = max(g.version for g in existing_versions)
graph.version = latest_version_number + 1
graph.version = max(g.version for g in existing_versions) + 1
current_active_version = next((v for v in existing_versions if v.is_active), None)
graph = graph_db.make_graph_model(graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
graph.validate_graph(for_run=False)
@@ -840,23 +842,27 @@ async def update_graph(
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
if new_graph_version.is_active:
await library_db.update_library_agent_version_and_settings(
user_id, new_graph_version
)
# Keep the library agent up to date with the new active version
await _update_library_agent_version_and_settings(user_id, new_graph_version)
# Handle activation of the new graph first to ensure continuity
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
# Ensure new version is the only active version
await graph_db.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
)
if current_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(current_active_version, user_id=user_id)
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
new_graph_version_with_subgraphs = await graph_db.get_graph(
graph_id,
new_graph_version.version,
user_id=user_id,
include_subgraphs=True,
)
assert new_graph_version_with_subgraphs
assert new_graph_version_with_subgraphs # make type checker happy
return new_graph_version_with_subgraphs
@@ -894,15 +900,33 @@ async def set_graph_active_version(
)
# Keep the library agent up to date with the new active version
await library_db.update_library_agent_version_and_settings(
user_id, new_active_graph
)
await _update_library_agent_version_and_settings(user_id, new_active_graph)
if current_active_graph and current_active_graph.version != new_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(current_active_graph, user_id=user_id)
async def _update_library_agent_version_and_settings(
user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
library = await library_db.update_agent_version_in_library(
user_id, agent_graph.id, agent_graph.version
)
updated_settings = GraphSettings.from_graph(
graph=agent_graph,
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
)
if updated_settings != library.settings:
library = await library_db.update_library_agent(
library_agent_id=library.id,
user_id=user_id,
settings=updated_settings,
)
return library
@v1_router.patch(
path="/graphs/{graph_id}/settings",
summary="Update graph settings",

View File

@@ -11,7 +11,7 @@ import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage
@@ -44,11 +44,11 @@ router = fastapi.APIRouter(
)
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
def _create_streaming_response(content: bytes, file) -> Response:
"""Create a streaming response for file content."""
return Response(
content=content,
media_type=file.mime_type,
media_type=file.mimeType,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)),
@@ -56,7 +56,7 @@ def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
)
async def _create_file_download_response(file: WorkspaceFile) -> Response:
async def _create_file_download_response(file) -> Response:
"""
Create a download response for a workspace file.
@@ -66,33 +66,33 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
storage = await get_workspace_storage()
# For local storage, stream the file directly
if file.storage_path.startswith("local://"):
content = await storage.retrieve(file.storage_path)
if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
url = await storage.get_download_url(file.storage_path, expires_in=300)
url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storage_path)
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
logger.error(
f"Failed to get signed URL for file {file.id} "
f"(storagePath={file.storage_path}): {e}",
f"(storagePath={file.storagePath}): {e}",
exc_info=True,
)
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storage_path)
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
f"(storagePath={file.storage_path}): {fallback_error}",
f"(storagePath={file.storagePath}): {fallback_error}",
exc_info=True,
)
raise

View File

@@ -26,7 +26,6 @@ import backend.api.features.executions.review.routes
import backend.api.features.library.db
import backend.api.features.library.model
import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.postmark.postmark
@@ -42,10 +41,6 @@ import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.copilot.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -123,21 +118,9 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:
@@ -344,11 +327,6 @@ app.include_router(
tags=["workspace"],
prefix="/api/workspace",
)
app.include_router(
mcp_routes.router,
tags=["v2", "mcp"],
prefix="/api/mcp",
)
app.include_router(
backend.api.features.oauth.router,
tags=["oauth"],

View File

@@ -66,24 +66,18 @@ async def event_broadcaster(manager: ConnectionManager):
execution_bus = AsyncRedisExecutionEventBus()
notification_bus = AsyncRedisNotificationEventBus()
try:
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
await asyncio.gather(execution_worker(), notification_worker())
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()
await asyncio.gather(execution_worker(), notification_worker())
async def authenticate_websocket(websocket: WebSocket) -> str:

View File

@@ -38,9 +38,7 @@ def main(**kwargs):
from backend.api.rest_api import AgentServer
from backend.api.ws_api import WebsocketServer
from backend.copilot.executor.manager import CoPilotExecutor
from backend.data.db_manager import DatabaseManager
from backend.executor import ExecutionManager, Scheduler
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.notifications import NotificationManager
run_processes(
@@ -50,7 +48,6 @@ def main(**kwargs):
WebsocketServer(),
AgentServer(),
ExecutionManager(),
CoPilotExecutor(),
**kwargs,
)

View File

@@ -3,19 +3,22 @@ import logging
import os
import re
from pathlib import Path
from typing import Sequence, Type, TypeVar
from typing import TYPE_CHECKING, TypeVar
from backend.blocks._base import AnyBlockSchema, BlockType
from backend.util.cache import cached
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.block import Block
T = TypeVar("T")
@cached(ttl_seconds=3600)
def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
from backend.blocks._base import Block
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
from backend.util.settings import Config
# Check if example blocks should be loaded from settings
@@ -47,8 +50,8 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
importlib.import_module(f".{module}", package=__name__)
# Load all Block instances from the available modules
available_blocks: dict[str, type["AnyBlockSchema"]] = {}
for block_cls in _all_subclasses(Block):
available_blocks: dict[str, type["Block"]] = {}
for block_cls in all_subclasses(Block):
class_name = block_cls.__name__
if class_name.endswith("Base"):
@@ -61,7 +64,7 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
"please name the class with 'Base' at the end"
)
block = block_cls() # pyright: ignore[reportAbstractUsage]
block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(
@@ -102,7 +105,7 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
available_blocks[block.id] = block_cls
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
from ._utils import is_block_auth_configured
from backend.data.block import is_block_auth_configured
filtered_blocks = {}
for block_id, block_cls in available_blocks.items():
@@ -112,48 +115,11 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
return filtered_blocks
def _all_subclasses(cls: type[T]) -> list[type[T]]:
__all__ = ["load_all_blocks"]
def all_subclasses(cls: type[T]) -> list[type[T]]:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses += _all_subclasses(subclass)
subclasses += all_subclasses(subclass)
return subclasses
# ============== Block access helper functions ============== #
def get_blocks() -> dict[str, Type["AnyBlockSchema"]]:
return load_all_blocks()
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> "AnyBlockSchema | None":
cls = get_blocks().get(block_id)
return cls() if cls else None
@cached(ttl_seconds=3600)
def get_webhook_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
]
@cached(ttl_seconds=3600)
def get_io_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
]
@cached(ttl_seconds=3600)
def get_human_in_the_loop_block_ids() -> Sequence[str]:
return [
id
for id, B in get_blocks().items()
if B().block_type == BlockType.HUMAN_IN_THE_LOOP
]

View File

@@ -1,740 +0,0 @@
import inspect
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Optional,
Type,
TypeAlias,
TypeVar,
cast,
get_origin,
)
import jsonref
import jsonschema
from pydantic import BaseModel
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
SchemaField,
is_credentials_field_name,
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.exceptions import (
BlockError,
BlockExecutionError,
BlockInputError,
BlockOutputError,
BlockUnknownError,
)
from backend.util.settings import Config
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, NodeExecutionStats
from ..data.graph import Link
app_config = Config()
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
class BlockType(Enum):
STANDARD = "Standard"
INPUT = "Input"
OUTPUT = "Output"
NOTE = "Note"
WEBHOOK = "Webhook"
WEBHOOK_MANUAL = "Webhook (manual)"
AGENT = "Agent"
AI = "AI"
AYRSHARE = "Ayrshare"
HUMAN_IN_THE_LOOP = "Human In The Loop"
MCP_TOOL = "MCP Tool"
class BlockCategory(Enum):
AI = "Block that leverages AI to perform a task."
SOCIAL = "Block that interacts with social media platforms."
TEXT = "Block that processes text data."
SEARCH = "Block that searches or extracts information from the internet."
BASIC = "Block that performs basic operations."
INPUT = "Block that interacts with input of the graph."
OUTPUT = "Block that interacts with output of the graph."
LOGIC = "Programming logic to control the flow of your agent"
COMMUNICATION = "Block that interacts with communication platforms."
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
DATA = "Block that interacts with structured data."
HARDWARE = "Block that interacts with hardware."
AGENT = "Block that interacts with other agents."
CRM = "Block that interacts with CRM services."
SAFETY = (
"Block that provides AI safety mechanisms such as detecting harmful content"
)
PRODUCTIVITY = "Block that helps with productivity"
ISSUE_TRACKING = "Block that helps with issue tracking"
MULTIMEDIA = "Block that interacts with multimedia content"
MARKETING = "Block that helps with marketing"
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}
class BlockCostType(str, Enum):
RUN = "run" # cost X credits per run
BYTE = "byte" # cost X credits per byte
SECOND = "second" # cost X credits per second
class BlockCost(BaseModel):
cost_amount: int
cost_filter: BlockInput
cost_type: BlockCostType
def __init__(
self,
cost_amount: int,
cost_type: BlockCostType = BlockCostType.RUN,
cost_filter: Optional[BlockInput] = None,
**data: Any,
) -> None:
super().__init__(
cost_amount=cost_amount,
cost_filter=cost_filter or {},
cost_type=cost_type,
**data,
)
class BlockInfo(BaseModel):
id: str
name: str
inputSchema: dict[str, Any]
outputSchema: dict[str, Any]
costs: list[BlockCost]
description: str
categories: list[dict[str, str]]
contributors: list[dict[str, Any]]
staticOutput: bool
uiType: str
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:
if cls.cached_jsonschema:
return cls.cached_jsonschema
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields that has a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
if one_key:
obj.update(obj[one_key][0])
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
return obj
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
return cls.cached_jsonschema
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(
schema=cls.jsonschema(),
data={k: v for k, v in data.items() if v is not None},
)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return cls.validate_data(data)
@classmethod
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
model_schema = cls.jsonschema().get("properties", {})
if not model_schema:
raise ValueError(f"Invalid model schema {cls}")
property_schema = model_schema.get(field_name)
if not property_schema:
raise ValueError(f"Invalid property name {field_name}")
return property_schema
@classmethod
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
"""
Validate the data against a specific property (one of the input/output name).
Returns the validation error message if the data does not match the schema.
"""
try:
property_schema = cls.get_field_schema(field_name)
jsonschema.validate(json.to_dict(data), property_schema)
return None
except jsonschema.ValidationError as e:
return str(e)
@classmethod
def get_fields(cls) -> set[str]:
return set(cls.model_fields.keys())
@classmethod
def get_required_fields(cls) -> set[str]:
return {
field
for field, field_info in cls.model_fields.items()
if field_info.is_required()
}
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""Validates the schema definition. Rules:
- Fields with annotation `CredentialsMetaInput` MUST be
named `credentials` or `*_credentials`
- Fields named `credentials` or `*_credentials` MUST be
of type `CredentialsMetaInput`
"""
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}
credentials_fields = cls.get_credentials_fields()
for field_name in cls.get_fields():
if is_credentials_field_name(field_name):
if field_name not in credentials_fields:
raise TypeError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
f"is not of type {CredentialsMetaInput.__name__}"
)
CredentialsMetaInput.validate_credentials_field_schema(
cls.get_field_schema(field_name), field_name
)
elif field_name in credentials_fields:
raise KeyError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
"has invalid name: must be 'credentials' or *_credentials"
)
@classmethod
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
return {
field_name: info.annotation
for field_name, info in cls.model_fields.items()
if (
inspect.isclass(info.annotation)
and issubclass(
get_origin(info.annotation) or info.annotation,
CredentialsMetaInput,
)
)
}
@classmethod
def get_auto_credentials_fields(cls) -> dict[str, dict[str, Any]]:
"""
Get fields that have auto_credentials metadata (e.g., GoogleDriveFileInput).
Returns a dict mapping kwarg_name -> {field_name, auto_credentials_config}
Raises:
ValueError: If multiple fields have the same kwarg_name, as this would
cause silent overwriting and only the last field would be processed.
"""
result: dict[str, dict[str, Any]] = {}
schema = cls.jsonschema()
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
auto_creds = field_schema.get("auto_credentials")
if auto_creds:
kwarg_name = auto_creds.get("kwarg_name", "credentials")
if kwarg_name in result:
raise ValueError(
f"Duplicate auto_credentials kwarg_name '{kwarg_name}' "
f"in fields '{result[kwarg_name]['field_name']}' and "
f"'{field_name}' on {cls.__qualname__}"
)
result[kwarg_name] = {
"field_name": field_name,
"config": auto_creds,
}
return result
@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
result = {}
# Regular credentials fields
for field_name in cls.get_credentials_fields().keys():
result[field_name] = CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
# Auto-generated credentials fields (from GoogleDriveFileInput etc.)
for kwarg_name, info in cls.get_auto_credentials_fields().items():
config = info["config"]
# Build a schema-like dict that CredentialsFieldInfo can parse
auto_schema = {
"credentials_provider": [config.get("provider", "google")],
"credentials_types": [config.get("type", "oauth2")],
"credentials_scopes": config.get("scopes"),
}
result[kwarg_name] = CredentialsFieldInfo.model_validate(
auto_schema, by_alias=True
)
return result
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
input_fields_from_nodes = {link.sink_name for link in links}
return input_fields_from_nodes - set(data)
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
return cls.get_required_fields() - set(data)
class BlockSchemaInput(BlockSchema):
"""
Base schema class for block inputs.
All block input schemas should extend this class for consistency.
"""
pass
class BlockSchemaOutput(BlockSchema):
"""
Base schema class for block outputs that includes a standard error field.
All block output schemas should extend this class to ensure consistent error handling.
"""
error: str = SchemaField(
description="Error message if the operation failed", default=""
)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchemaInput)
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchemaOutput)
class EmptyInputSchema(BlockSchemaInput):
pass
class EmptyOutputSchema(BlockSchemaOutput):
pass
# For backward compatibility - will be deprecated
EmptySchema = EmptyOutputSchema
# --8<-- [start:BlockWebhookConfig]
class BlockManualWebhookConfig(BaseModel):
"""
Configuration model for webhook-triggered blocks on which
the user has to manually set up the webhook at the provider.
"""
provider: ProviderName
"""The service provider that the webhook connects to"""
webhook_type: str
"""
Identifier for the webhook type. E.g. GitHub has repo and organization level hooks.
Only for use in the corresponding `WebhooksManager`.
"""
event_filter_input: str = ""
"""
Name of the block's event filter input.
Leave empty if the corresponding webhook doesn't have distinct event/payload types.
"""
event_format: str = "{event}"
"""
Template string for the event(s) that a block instance subscribes to.
Applied individually to each event selected in the event filter input.
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
"""
class BlockWebhookConfig(BlockManualWebhookConfig):
"""
Configuration model for webhook-triggered blocks for which
the webhook can be automatically set up through the provider's API.
"""
resource_format: str
"""
Template string for the resource that a block instance subscribes to.
Fields will be filled from the block's inputs (except `payload`).
Example: `f"{repo}/pull_requests"` (note: not how it's actually implemented)
Only for use in the corresponding `WebhooksManager`.
"""
# --8<-- [end:BlockWebhookConfig]
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
def __init__(
self,
id: str = "",
description: str = "",
contributors: list["ContributorDetails"] = [],
categories: set[BlockCategory] | None = None,
input_schema: Type[BlockSchemaInputType] = EmptyInputSchema,
output_schema: Type[BlockSchemaOutputType] = EmptyOutputSchema,
test_input: BlockInput | list[BlockInput] | None = None,
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
test_mock: dict[str, Any] | None = None,
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
disabled: bool = False,
static_output: bool = False,
block_type: BlockType = BlockType.STANDARD,
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
is_sensitive_action: bool = False,
):
"""
Initialize the block with the given schema.
Args:
id: The unique identifier for the block, this value will be persisted in the
DB. So it should be a unique and constant across the application run.
Use the UUID format for the ID.
description: The description of the block, explaining what the block does.
contributors: The list of contributors who contributed to the block.
input_schema: The schema, defined as a Pydantic model, for the input data.
output_schema: The schema, defined as a Pydantic model, for the output data.
test_input: The list or single sample input data for the block, for testing.
test_output: The list or single expected output if the test_input is run.
test_mock: function names on the block implementation to mock on test run.
disabled: If the block is disabled, it will not be available for execution.
static_output: Whether the output links of the block are static by default.
"""
from backend.data.model import NodeExecutionStats
self.id = id
self.input_schema = input_schema
self.output_schema = output_schema
self.test_input = test_input
self.test_output = test_output
self.test_mock = test_mock
self.test_credentials = test_credentials
self.description = description
self.categories = categories or set()
self.contributors = contributors or set()
self.disabled = disabled
self.static_output = static_output
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
# Enforce presence of credentials field on auto-setup webhook blocks
if not (cred_fields := self.input_schema.get_credentials_fields()):
raise TypeError(
"credentials field is required on auto-setup webhook blocks"
)
# Disallow multiple credentials inputs on webhook blocks
elif len(cred_fields) > 1:
raise ValueError(
"Multiple credentials inputs not supported on webhook blocks"
)
self.block_type = BlockType.WEBHOOK
else:
self.block_type = BlockType.WEBHOOK_MANUAL
# Enforce shape of webhook event filter, if present
if self.webhook_config.event_filter_input:
event_filter_field = self.input_schema.model_fields[
self.webhook_config.event_filter_input
]
if not (
isinstance(event_filter_field.annotation, type)
and issubclass(event_filter_field.annotation, BaseModel)
and all(
field.annotation is bool
for field in event_filter_field.annotation.model_fields.values()
)
):
raise NotImplementedError(
f"{self.name} has an invalid webhook event selector: "
"field must be a BaseModel and all its fields must be boolean"
)
# Enforce presence of 'payload' input
if "payload" not in self.input_schema.model_fields:
raise TypeError(
f"{self.name} is webhook-triggered but has no 'payload' input"
)
# Disable webhook-triggered block if webhook functionality not available
if not app_config.platform_base_url:
self.disabled = True
@abstractmethod
async def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
"""
Run the block with the given input data.
Args:
input_data: The input data with the structure of input_schema.
Kwargs: Currently 14/02/2025 these include
graph_id: The ID of the graph.
node_id: The ID of the node.
graph_exec_id: The ID of the graph execution.
node_exec_id: The ID of the node execution.
user_id: The ID of the user.
Returns:
A Generator that yields (output_name, output_data).
output_name: One of the output name defined in Block's output_schema.
output_data: The data for the output_name, matching the defined schema.
"""
# --- satisfy the type checker, never executed -------------
if False: # noqa: SIM115
yield "name", "value" # pyright: ignore[reportMissingYield]
raise NotImplementedError(f"{self.name} does not implement the run method.")
async def run_once(
self, input_data: BlockSchemaInputType, output: str, **kwargs
) -> Any:
async for item in self.run(input_data, **kwargs):
name, data = item
if name == output:
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
self.execution_stats += stats
return self.execution_stats
@property
def name(self):
return self.__class__.__name__
def to_dict(self):
return {
"id": self.id,
"name": self.name,
"inputSchema": self.input_schema.jsonschema(),
"outputSchema": self.output_schema.jsonschema(),
"description": self.description,
"categories": [category.dict() for category in self.categories],
"contributors": [
contributor.model_dump() for contributor in self.contributors
],
"staticOutput": self.static_output,
"uiType": self.block_type.value,
}
def get_info(self) -> BlockInfo:
from backend.data.credit import get_block_cost
return BlockInfo(
id=self.id,
name=self.name,
inputSchema=self.input_schema.jsonschema(),
outputSchema=self.output_schema.jsonschema(),
costs=get_block_cost(self),
description=self.description,
categories=[category.dict() for category in self.categories],
contributors=[
contributor.model_dump() for contributor in self.contributors
],
staticOutput=self.static_output,
uiType=self.block_type.value,
)
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
try:
async for output_name, output_data in self._execute(input_data, **kwargs):
yield output_name, output_data
except Exception as ex:
if isinstance(ex, BlockError):
raise ex
else:
raise (
BlockExecutionError
if isinstance(ex, ValueError)
else BlockUnknownError
)(
message=str(ex),
block_name=self.name,
block_id=self.id,
) from ex
async def is_block_exec_need_review(
self,
input_data: BlockInput,
*,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
**kwargs,
) -> tuple[bool, BlockInput]:
"""
Check if this block execution needs human review and handle the review process.
Returns:
Tuple of (should_pause, input_data_to_use)
- should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer)
"""
if not (
self.is_sensitive_action and execution_context.sensitive_action_safe_mode
):
return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper
# Handle the review request and get decision
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
block_name=self.name,
editable=True,
)
if decision is None:
# We're awaiting review - pause execution
return True, input_data
if not decision.should_proceed:
# Review was rejected, raise an error to stop execution
raise BlockExecutionError(
message=f"Block execution rejected by reviewer: {decision.message}",
block_name=self.name,
block_id=self.id,
)
# Review was approved - use the potentially modified data
# ReviewResult.data must be a dict for block inputs
reviewed_data = decision.review_result.data
if not isinstance(reviewed_data, dict):
raise BlockExecutionError(
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
block_name=self.name,
block_id=self.id,
)
return False, reviewed_data
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Check for review requirement only if running within a graph execution context
# Direct block execution (e.g., from chat) skips the review process
has_graph_context = all(
key in kwargs
for key in (
"node_exec_id",
"graph_exec_id",
"graph_id",
"execution_context",
)
)
if has_graph_context:
should_pause, input_data = await self.is_block_exec_need_review(
input_data, **kwargs
)
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,
):
if output_name == "error":
raise BlockExecutionError(
message=output_data, block_name=self.name, block_id=self.id
)
if self.block_type == BlockType.STANDARD and (
error := self.output_schema.validate_field(output_name, output_data)
):
raise BlockOutputError(
message=f"Block produced an invalid output data: {error}",
block_name=self.name,
block_id=self.id,
)
yield output_name, output_data
def is_triggered_by_event_type(
self, trigger_config: dict[str, Any], event_type: str
) -> bool:
if not self.webhook_config:
raise TypeError("This method can't be used on non-trigger blocks")
if not self.webhook_config.event_filter_input:
return True
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
if not event_filter:
raise ValueError("Event filter is not configured on trigger")
return event_type in [
self.webhook_config.event_format.format(event=k)
for k in event_filter
if event_filter[k] is True
]
# Type alias for any block with standard input/output schemas
AnyBlockSchema: TypeAlias = Block[BlockSchemaInput, BlockSchemaOutput]

View File

@@ -1,122 +0,0 @@
import logging
import os
from backend.integrations.providers import ProviderName
from ._base import AnyBlockSchema
logger = logging.getLogger(__name__)
def is_block_auth_configured(
block_cls: type[AnyBlockSchema],
) -> bool:
"""
Check if a block has a valid authentication method configured at runtime.
For example if a block is an OAuth-only block and there env vars are not set,
do not show it in the UI.
"""
from backend.sdk.registry import AutoRegistry
# Create an instance to access input_schema
try:
block = block_cls()
except Exception as e:
# If we can't create a block instance, assume it's not OAuth-only
logger.error(f"Error creating block instance for {block_cls.__name__}: {e}")
return True
logger.debug(
f"Checking if block {block_cls.__name__} has a valid provider configured"
)
# Get all credential inputs from input schema
credential_inputs = block.input_schema.get_credentials_fields_info()
required_inputs = block.input_schema.get_required_fields()
if not credential_inputs:
logger.debug(
f"Block {block_cls.__name__} has no credential inputs - Treating as valid"
)
return True
# Check credential inputs
if len(required_inputs.intersection(credential_inputs.keys())) == 0:
logger.debug(
f"Block {block_cls.__name__} has only optional credential inputs"
" - will work without credentials configured"
)
# Check if the credential inputs for this block are correctly configured
for field_name, field_info in credential_inputs.items():
provider_names = field_info.provider
if not provider_names:
logger.warning(
f"Block {block_cls.__name__} "
f"has credential input '{field_name}' with no provider options"
" - Disabling"
)
return False
# If a field has multiple possible providers, each one needs to be usable to
# prevent breaking the UX
for _provider_name in provider_names:
provider_name = _provider_name.value
if provider_name in ProviderName.__members__.values():
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' is part of the legacy provider system"
" - Treating as valid"
)
break
provider = AutoRegistry.get_provider(provider_name)
if not provider:
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"refers to unknown provider '{provider_name}' - Disabling"
)
return False
# Check the provider's supported auth types
if field_info.supported_types != provider.supported_auth_types:
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"has mismatched supported auth types (field <> Provider): "
f"{field_info.supported_types} != {provider.supported_auth_types}"
)
if not (supported_auth_types := provider.supported_auth_types):
# No auth methods are been configured for this provider
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' "
"has no authentication methods configured - Disabling"
)
return False
# Check if provider supports OAuth
if "oauth2" in supported_auth_types:
# Check if OAuth environment variables are set
if (oauth_config := provider.oauth_config) and bool(
os.getenv(oauth_config.client_id_env_var)
and os.getenv(oauth_config.client_secret_env_var)
):
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' is configured for OAuth"
)
else:
logger.error(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' "
"is missing OAuth client ID or secret - Disabling"
)
return False
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' is valid; "
f"supported credential types: {', '.join(field_info.supported_types)}"
)
return True

View File

@@ -1,7 +1,7 @@
import logging
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional
from backend.blocks._base import (
from backend.data.block import (
Block,
BlockCategory,
BlockInput,
@@ -9,15 +9,13 @@ from backend.blocks._base import (
BlockSchema,
BlockSchemaInput,
BlockType,
get_block,
)
from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry
if TYPE_CHECKING:
from backend.executor.utils import LogMetadata
_logger = logging.getLogger(__name__)
@@ -126,10 +124,9 @@ class AgentExecutorBlock(Block):
graph_version: int,
graph_exec_id: str,
user_id: str,
logger: "LogMetadata",
logger,
) -> BlockOutput:
from backend.blocks import get_block
from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils
@@ -201,7 +198,7 @@ class AgentExecutorBlock(Block):
self,
graph_exec_id: str,
user_id: str,
logger: "LogMetadata",
logger,
) -> None:
from backend.executor import utils as execution_utils

View File

@@ -1,11 +1,5 @@
from typing import Any
from backend.blocks._base import (
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
@@ -17,6 +11,12 @@ from backend.blocks.llm import (
LLMResponse,
llm_call,
)
from backend.data.block import (
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField

View File

@@ -6,7 +6,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.blocks._base import (
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -5,12 +5,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.blocks._base import (
Block,
BlockCategory,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,

View File

@@ -6,7 +6,7 @@ from typing import Literal
from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from backend.blocks._base import (
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -6,7 +6,7 @@ from typing import Literal
from pydantic import SecretStr
from backend.blocks._base import (
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,

View File

@@ -1,10 +1,3 @@
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
@@ -17,6 +10,13 @@ from backend.blocks.apollo.models import (
PrimaryPhone,
SearchOrganizationsRequest,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField

View File

@@ -1,12 +1,5 @@
import asyncio
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
@@ -21,6 +14,13 @@ from backend.blocks.apollo.models import (
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField

View File

@@ -1,10 +1,3 @@
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
@@ -13,6 +6,13 @@ from backend.blocks.apollo._auth import (
ApolloCredentialsInput,
)
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField

View File

@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel, Field
from backend.blocks._base import BlockSchemaInput
from backend.data.block import BlockSchemaInput
from backend.data.model import SchemaField, UserIntegrations
from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client

View File

@@ -1,7 +1,7 @@
import enum
from typing import Any
from backend.blocks._base import (
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
@@ -126,7 +126,6 @@ class PrintToConsoleBlock(Block):
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
is_sensitive_action=True,
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
test_output=[
("output", "Hello, World!"),
("status", "printed"),

Some files were not shown because too many files have changed in this diff Show More