mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-10 06:45:28 -05:00
Compare commits
2 Commits
feat/mcp-b
...
fix/sentry
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d58df37238 | ||
|
|
9c41512944 |
2
.github/workflows/classic-frontend-ci.yml
vendored
2
.github/workflows/classic-frontend-ci.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||||
if: github.event_name == 'push'
|
if: github.event_name == 'push'
|
||||||
uses: peter-evans/create-pull-request@v8
|
uses: peter-evans/create-pull-request@v7
|
||||||
with:
|
with:
|
||||||
add-paths: classic/frontend/build/web
|
add-paths: classic/frontend/build/web
|
||||||
base: ${{ github.ref_name }}
|
base: ${{ github.ref_name }}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Get CI failure details
|
- name: Get CI failure details
|
||||||
id: failure_details
|
id: failure_details
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const run = await github.rest.actions.getWorkflowRun({
|
const run = await github.rest.actions.getWorkflowRun({
|
||||||
|
|||||||
9
.github/workflows/claude-dependabot.yml
vendored
9
.github/workflows/claude-dependabot.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -78,7 +78,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -124,7 +124,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
@@ -309,7 +309,6 @@ jobs:
|
|||||||
uses: anthropics/claude-code-action@v1
|
uses: anthropics/claude-code-action@v1
|
||||||
with:
|
with:
|
||||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||||
allowed_bots: "dependabot[bot]"
|
|
||||||
claude_args: |
|
claude_args: |
|
||||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||||
prompt: |
|
prompt: |
|
||||||
|
|||||||
8
.github/workflows/claude.yml
vendored
8
.github/workflows/claude.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -94,7 +94,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -140,7 +140,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
|
|||||||
8
.github/workflows/copilot-setup-steps.yml
vendored
8
.github/workflows/copilot-setup-steps.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -76,7 +76,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -132,7 +132,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
|
|||||||
2
.github/workflows/docs-block-sync.yml
vendored
2
.github/workflows/docs-block-sync.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
2
.github/workflows/docs-claude-review.yml
vendored
2
.github/workflows/docs-claude-review.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
2
.github/workflows/docs-enhance.yml
vendored
2
.github/workflows/docs-enhance.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -88,7 +88,7 @@ jobs:
|
|||||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ jobs:
|
|||||||
- name: Check comment permissions and deployment status
|
- name: Check comment permissions and deployment status
|
||||||
id: check_status
|
id: check_status
|
||||||
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const commentBody = context.payload.comment.body.trim();
|
const commentBody = context.payload.comment.body.trim();
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post permission denied comment
|
- name: Post permission denied comment
|
||||||
if: steps.check_status.outputs.permission_denied == 'true'
|
if: steps.check_status.outputs.permission_denied == 'true'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -68,7 +68,7 @@ jobs:
|
|||||||
- name: Get PR details for deployment
|
- name: Get PR details for deployment
|
||||||
id: pr_details
|
id: pr_details
|
||||||
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const pr = await github.rest.pulls.get({
|
const pr = await github.rest.pulls.get({
|
||||||
@@ -98,7 +98,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post deploy success comment
|
- name: Post deploy success comment
|
||||||
if: steps.check_status.outputs.should_deploy == 'true'
|
if: steps.check_status.outputs.should_deploy == 'true'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -126,7 +126,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post undeploy success comment
|
- name: Post undeploy success comment
|
||||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -139,7 +139,7 @@ jobs:
|
|||||||
- name: Check deployment status on PR close
|
- name: Check deployment status on PR close
|
||||||
id: check_pr_close
|
id: check_pr_close
|
||||||
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const comments = await github.rest.issues.listComments({
|
const comments = await github.rest.issues.listComments({
|
||||||
@@ -187,7 +187,7 @@ jobs:
|
|||||||
github.event_name == 'pull_request' &&
|
github.event_name == 'pull_request' &&
|
||||||
github.event.action == 'closed' &&
|
github.event.action == 'closed' &&
|
||||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
|
|||||||
38
.github/workflows/platform-frontend-ci.yml
vendored
38
.github/workflows/platform-frontend-ci.yml
vendored
@@ -27,22 +27,13 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
outputs:
|
||||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||||
components-changed: ${{ steps.filter.outputs.components }}
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Check for component changes
|
|
||||||
uses: dorny/paths-filter@v3
|
|
||||||
id: filter
|
|
||||||
with:
|
|
||||||
filters: |
|
|
||||||
components:
|
|
||||||
- 'autogpt_platform/frontend/src/components/**'
|
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -54,7 +45,7 @@ jobs:
|
|||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
key: ${{ steps.cache-key.outputs.key }}
|
||||||
@@ -74,7 +65,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -82,7 +73,7 @@ jobs:
|
|||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
@@ -99,11 +90,8 @@ jobs:
|
|||||||
chromatic:
|
chromatic:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: setup
|
needs: setup
|
||||||
# Disabled: to re-enable, remove 'false &&' from the condition below
|
# Only run on dev branch pushes or PRs targeting dev
|
||||||
if: >-
|
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
|
||||||
false
|
|
||||||
&& (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
|
|
||||||
&& needs.setup.outputs.components-changed == 'true'
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
@@ -112,7 +100,7 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -120,7 +108,7 @@ jobs:
|
|||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
@@ -153,7 +141,7 @@ jobs:
|
|||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -176,7 +164,7 @@ jobs:
|
|||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
- name: Cache Docker layers
|
- name: Cache Docker layers
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: /tmp/.buildx-cache
|
path: /tmp/.buildx-cache
|
||||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -231,7 +219,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
@@ -282,7 +270,7 @@ jobs:
|
|||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -290,7 +278,7 @@ jobs:
|
|||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
|||||||
8
.github/workflows/platform-fullstack-ci.yml
vendored
8
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ jobs:
|
|||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
key: ${{ steps.cache-key.outputs.key }}
|
||||||
@@ -68,7 +68,7 @@ jobs:
|
|||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
|||||||
1652
autogpt_platform/autogpt_libs/poetry.lock
generated
1652
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
|
|||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.10,<4.0"
|
python = ">=3.10,<4.0"
|
||||||
colorama = "^0.4.6"
|
colorama = "^0.4.6"
|
||||||
cryptography = "^46.0"
|
cryptography = "^45.0"
|
||||||
expiringdict = "^1.2.2"
|
expiringdict = "^1.2.2"
|
||||||
fastapi = "^0.128.0"
|
fastapi = "^0.116.1"
|
||||||
google-cloud-logging = "^3.13.0"
|
google-cloud-logging = "^3.12.1"
|
||||||
launchdarkly-server-sdk = "^9.14.1"
|
launchdarkly-server-sdk = "^9.12.0"
|
||||||
pydantic = "^2.12.5"
|
pydantic = "^2.11.7"
|
||||||
pydantic-settings = "^2.12.0"
|
pydantic-settings = "^2.10.1"
|
||||||
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
supabase = "^2.27.2"
|
supabase = "^2.16.0"
|
||||||
uvicorn = "^0.40.0"
|
uvicorn = "^0.35.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pyright = "^1.1.408"
|
pyright = "^1.1.404"
|
||||||
pytest = "^8.4.1"
|
pytest = "^8.4.1"
|
||||||
pytest-asyncio = "^1.3.0"
|
pytest-asyncio = "^1.1.0"
|
||||||
pytest-mock = "^3.15.1"
|
pytest-mock = "^3.14.1"
|
||||||
pytest-cov = "^6.2.1"
|
pytest-cov = "^6.2.1"
|
||||||
ruff = "^0.15.0"
|
ruff = "^0.12.11"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|||||||
@@ -152,7 +152,6 @@ REPLICATE_API_KEY=
|
|||||||
REVID_API_KEY=
|
REVID_API_KEY=
|
||||||
SCREENSHOTONE_API_KEY=
|
SCREENSHOTONE_API_KEY=
|
||||||
UNREAL_SPEECH_API_KEY=
|
UNREAL_SPEECH_API_KEY=
|
||||||
ELEVENLABS_API_KEY=
|
|
||||||
|
|
||||||
# Data & Search Services
|
# Data & Search Services
|
||||||
E2B_API_KEY=
|
E2B_API_KEY=
|
||||||
|
|||||||
3
autogpt_platform/backend/.gitignore
vendored
3
autogpt_platform/backend/.gitignore
vendored
@@ -19,6 +19,3 @@ load-tests/*.json
|
|||||||
load-tests/*.log
|
load-tests/*.log
|
||||||
load-tests/node_modules/*
|
load-tests/node_modules/*
|
||||||
migrations/*/rollback*.sql
|
migrations/*/rollback*.sql
|
||||||
|
|
||||||
# Workspace files
|
|
||||||
workspaces/
|
|
||||||
|
|||||||
@@ -62,12 +62,10 @@ ENV POETRY_HOME=/opt/poetry \
|
|||||||
DEBIAN_FRONTEND=noninteractive
|
DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PATH=/opt/poetry/bin:$PATH
|
ENV PATH=/opt/poetry/bin:$PATH
|
||||||
|
|
||||||
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
# Install Python without upgrading system-managed packages
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
python3.13 \
|
python3.13 \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
ffmpeg \
|
|
||||||
imagemagick \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy only necessary files from builder
|
# Copy only necessary files from builder
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):
|
|||||||
|
|
||||||
# OpenAI API Configuration
|
# OpenAI API Configuration
|
||||||
model: str = Field(
|
model: str = Field(
|
||||||
default="anthropic/claude-opus-4.6", description="Default model to use"
|
default="anthropic/claude-opus-4.5", description="Default model to use"
|
||||||
)
|
)
|
||||||
title_model: str = Field(
|
title_model: str = Field(
|
||||||
default="openai/gpt-4o-mini",
|
default="openai/gpt-4o-mini",
|
||||||
|
|||||||
@@ -45,7 +45,10 @@ async def create_chat_session(
|
|||||||
successfulAgentRuns=SafeJson({}),
|
successfulAgentRuns=SafeJson({}),
|
||||||
successfulAgentSchedules=SafeJson({}),
|
successfulAgentSchedules=SafeJson({}),
|
||||||
)
|
)
|
||||||
return await PrismaChatSession.prisma().create(data=data)
|
return await PrismaChatSession.prisma().create(
|
||||||
|
data=data,
|
||||||
|
include={"Messages": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def update_chat_session(
|
async def update_chat_session(
|
||||||
|
|||||||
@@ -266,38 +266,12 @@ async def stream_chat_post(
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
|
||||||
|
|
||||||
stream_start_time = time.perf_counter()
|
|
||||||
|
|
||||||
# Base log metadata (task_id added after creation)
|
|
||||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
|
||||||
if user_id:
|
|
||||||
log_meta["user_id"] = user_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
|
||||||
f"user={user_id}, message_len={len(request.message)}",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
|
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
logger.info(
|
|
||||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a task in the stream registry for reconnection support
|
# Create a task in the stream registry for reconnection support
|
||||||
task_id = str(uuid_module.uuid4())
|
task_id = str(uuid_module.uuid4())
|
||||||
operation_id = str(uuid_module.uuid4())
|
operation_id = str(uuid_module.uuid4())
|
||||||
log_meta["task_id"] = task_id
|
|
||||||
|
|
||||||
task_create_start = time.perf_counter()
|
|
||||||
await stream_registry.create_task(
|
await stream_registry.create_task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -306,46 +280,14 @@ async def stream_chat_post(
|
|||||||
tool_name="chat",
|
tool_name="chat",
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
)
|
)
|
||||||
logger.info(
|
|
||||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Background task that runs the AI generation independently of SSE connection
|
# Background task that runs the AI generation independently of SSE connection
|
||||||
async def run_ai_generation():
|
async def run_ai_generation():
|
||||||
import time as time_module
|
|
||||||
|
|
||||||
gen_start_time = time_module.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
first_chunk_time, ttfc = None, None
|
|
||||||
chunk_count = 0
|
|
||||||
try:
|
try:
|
||||||
# Emit a start event with task_id for reconnection
|
# Emit a start event with task_id for reconnection
|
||||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||||
logger.info(
|
|
||||||
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time)*1000:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
|
|
||||||
* 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"[TIMING] Calling stream_chat_completion",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
session_id,
|
session_id,
|
||||||
request.message,
|
request.message,
|
||||||
@@ -354,202 +296,54 @@ async def stream_chat_post(
|
|||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
context=request.context,
|
context=request.context,
|
||||||
):
|
):
|
||||||
chunk_count += 1
|
|
||||||
if first_chunk_time is None:
|
|
||||||
first_chunk_time = time_module.perf_counter()
|
|
||||||
ttfc = first_chunk_time - gen_start_time
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"chunk_type": type(chunk).__name__,
|
|
||||||
"time_to_first_chunk_ms": ttfc * 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Write to Redis (subscribers will receive via XREAD)
|
# Write to Redis (subscribers will receive via XREAD)
|
||||||
await stream_registry.publish_chunk(task_id, chunk)
|
await stream_registry.publish_chunk(task_id, chunk)
|
||||||
|
|
||||||
gen_end_time = time_module.perf_counter()
|
# Mark task as completed
|
||||||
total_time = (gen_end_time - gen_start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
|
||||||
f"task={task_id}, session={session_id}, "
|
|
||||||
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"time_to_first_chunk_ms": (
|
|
||||||
ttfc * 1000 if ttfc is not None else None
|
|
||||||
),
|
|
||||||
"n_chunks": chunk_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await stream_registry.mark_task_completed(task_id, "completed")
|
await stream_registry.mark_task_completed(task_id, "completed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed = time_module.perf_counter() - gen_start_time
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
f"Error in background AI generation for session {session_id}: {e}"
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed * 1000,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
await stream_registry.mark_task_completed(task_id, "failed")
|
await stream_registry.mark_task_completed(task_id, "failed")
|
||||||
|
|
||||||
# Start the AI generation in a background task
|
# Start the AI generation in a background task
|
||||||
bg_task = asyncio.create_task(run_ai_generation())
|
bg_task = asyncio.create_task(run_ai_generation())
|
||||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
# SSE endpoint that subscribes to the task's stream
|
# SSE endpoint that subscribes to the task's stream
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
import time as time_module
|
|
||||||
|
|
||||||
event_gen_start = time_module.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
|
||||||
f"user={user_id}",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
subscriber_queue = None
|
subscriber_queue = None
|
||||||
first_chunk_yielded = False
|
|
||||||
chunks_yielded = 0
|
|
||||||
try:
|
try:
|
||||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||||
subscribe_start = time_module.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
"[TIMING] Calling subscribe_to_task",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
last_message_id="0-0", # Get all messages from the beginning
|
last_message_id="0-0", # Get all messages from the beginning
|
||||||
)
|
)
|
||||||
subscribe_time = (time_module.perf_counter() - subscribe_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] subscribe_to_task completed in {subscribe_time:.1f}ms, "
|
|
||||||
f"queue_ok={subscriber_queue is not None}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": subscribe_time,
|
|
||||||
"queue_obtained": subscriber_queue is not None,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if subscriber_queue is None:
|
if subscriber_queue is None:
|
||||||
logger.info(
|
|
||||||
"[TIMING] subscriber_queue is None, yielding finish",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
yield StreamFinish().to_sse()
|
yield StreamFinish().to_sse()
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
return
|
return
|
||||||
|
|
||||||
# Read from the subscriber queue and yield to SSE
|
# Read from the subscriber queue and yield to SSE
|
||||||
logger.info(
|
|
||||||
"[TIMING] Starting to read from subscriber_queue",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
queue_wait_start = time_module.perf_counter()
|
|
||||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||||
queue_wait_time = (
|
|
||||||
time_module.perf_counter() - queue_wait_start
|
|
||||||
) * 1000
|
|
||||||
chunks_yielded += 1
|
|
||||||
|
|
||||||
if not first_chunk_yielded:
|
|
||||||
first_chunk_yielded = True
|
|
||||||
elapsed = time_module.perf_counter() - event_gen_start
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
|
|
||||||
f"type={type(chunk).__name__}, "
|
|
||||||
f"wait={queue_wait_time:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"chunk_type": type(chunk).__name__,
|
|
||||||
"elapsed_ms": elapsed * 1000,
|
|
||||||
"queue_wait_ms": queue_wait_time,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
elif chunks_yielded % 50 == 0:
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Chunk #{chunks_yielded}, "
|
|
||||||
f"type={type(chunk).__name__}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"chunk_number": chunks_yielded,
|
|
||||||
"chunk_type": type(chunk).__name__,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
yield chunk.to_sse()
|
yield chunk.to_sse()
|
||||||
|
|
||||||
# Check for finish signal
|
# Check for finish signal
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
total_time = time_module.perf_counter() - event_gen_start
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
|
|
||||||
f"n_chunks={chunks_yielded}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"chunks_yielded": chunks_yielded,
|
|
||||||
"total_time_ms": total_time * 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# Send heartbeat to keep connection alive
|
# Send heartbeat to keep connection alive
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Heartbeat timeout, chunks_so_far={chunks_yielded}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {**log_meta, "chunks_so_far": chunks_yielded}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
yield StreamHeartbeat().to_sse()
|
yield StreamHeartbeat().to_sse()
|
||||||
|
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
logger.info(
|
|
||||||
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"chunks_yielded": chunks_yielded,
|
|
||||||
"reason": "client_disconnect",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
pass # Client disconnected - background task continues
|
pass # Client disconnected - background task continues
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
||||||
logger.error(
|
|
||||||
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||||
if subscriber_queue is not None:
|
if subscriber_queue is not None:
|
||||||
@@ -563,18 +357,6 @@ async def stream_chat_post(
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||||
total_time = time_module.perf_counter() - event_gen_start
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
|
||||||
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time * 1000,
|
|
||||||
"chunks_yielded": chunks_yielded,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
@@ -643,7 +425,7 @@ async def stream_chat_get(
|
|||||||
"Chat stream completed",
|
"Chat stream completed",
|
||||||
extra={
|
extra={
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"n_chunks": chunk_count,
|
"chunk_count": chunk_count,
|
||||||
"first_chunk_type": first_chunk_type,
|
"first_chunk_type": first_chunk_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from backend.data.understanding import (
|
|||||||
get_business_understanding,
|
get_business_understanding,
|
||||||
)
|
)
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.util.settings import AppEnvironment, Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
from . import db as chat_db
|
from . import db as chat_db
|
||||||
from . import stream_registry
|
from . import stream_registry
|
||||||
@@ -222,18 +222,8 @@ async def _get_system_prompt_template(context: str) -> str:
|
|||||||
try:
|
try:
|
||||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||||
# Use asyncio.to_thread to avoid blocking the event loop
|
# Use asyncio.to_thread to avoid blocking the event loop
|
||||||
# In non-production environments, fetch the latest prompt version
|
|
||||||
# instead of the production-labeled version for easier testing
|
|
||||||
label = (
|
|
||||||
None
|
|
||||||
if settings.config.app_env == AppEnvironment.PRODUCTION
|
|
||||||
else "latest"
|
|
||||||
)
|
|
||||||
prompt = await asyncio.to_thread(
|
prompt = await asyncio.to_thread(
|
||||||
langfuse.get_prompt,
|
langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
|
||||||
config.langfuse_prompt_name,
|
|
||||||
label=label,
|
|
||||||
cache_ttl_seconds=0,
|
|
||||||
)
|
)
|
||||||
return prompt.compile(users_information=context)
|
return prompt.compile(users_information=context)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -371,45 +361,21 @@ async def stream_chat_completion(
|
|||||||
ValueError: If max_context_messages is exceeded
|
ValueError: If max_context_messages is exceeded
|
||||||
|
|
||||||
"""
|
"""
|
||||||
completion_start = time.monotonic()
|
|
||||||
|
|
||||||
# Build log metadata for structured logging
|
|
||||||
log_meta = {"component": "ChatService", "session_id": session_id}
|
|
||||||
if user_id:
|
|
||||||
log_meta["user_id"] = user_id
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
||||||
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"message_len": len(message) if message else 0,
|
|
||||||
"is_user_message": is_user_message,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only fetch from Redis if session not provided (initial call)
|
# Only fetch from Redis if session not provided (initial call)
|
||||||
if session is None:
|
if session is None:
|
||||||
fetch_start = time.monotonic()
|
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
fetch_time = (time.monotonic() - fetch_start) * 1000
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
|
||||||
f"n_messages={len(session.messages) if session else 0}",
|
f"message_count={len(session.messages) if session else 0}"
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": fetch_time,
|
|
||||||
"n_messages": len(session.messages) if session else 0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
f"Using provided session object: {session.session_id}, "
|
||||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
f"message_count={len(session.messages)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
@@ -430,25 +396,17 @@ async def stream_chat_completion(
|
|||||||
|
|
||||||
# Track user message in PostHog
|
# Track user message in PostHog
|
||||||
if is_user_message:
|
if is_user_message:
|
||||||
posthog_start = time.monotonic()
|
|
||||||
track_user_message(
|
track_user_message(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
message_length=len(message),
|
message_length=len(message),
|
||||||
)
|
)
|
||||||
posthog_time = (time.monotonic() - posthog_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
upsert_start = time.monotonic()
|
|
||||||
session = await upsert_chat_session(session)
|
|
||||||
upsert_time = (time.monotonic() - upsert_start) * 1000
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
f"Upserting session: {session.session_id} with user id {session.user_id}, "
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
|
f"message_count={len(session.messages)}"
|
||||||
)
|
)
|
||||||
|
session = await upsert_chat_session(session)
|
||||||
assert session, "Session not found"
|
assert session, "Session not found"
|
||||||
|
|
||||||
# Generate title for new sessions on first user message (non-blocking)
|
# Generate title for new sessions on first user message (non-blocking)
|
||||||
@@ -486,13 +444,7 @@ async def stream_chat_completion(
|
|||||||
asyncio.create_task(_update_title())
|
asyncio.create_task(_update_title())
|
||||||
|
|
||||||
# Build system prompt with business understanding
|
# Build system prompt with business understanding
|
||||||
prompt_start = time.monotonic()
|
|
||||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||||
prompt_time = (time.monotonic() - prompt_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize variables for streaming
|
# Initialize variables for streaming
|
||||||
assistant_response = ChatMessage(
|
assistant_response = ChatMessage(
|
||||||
@@ -521,18 +473,9 @@ async def stream_chat_completion(
|
|||||||
text_block_id = str(uuid_module.uuid4())
|
text_block_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
# Yield message start
|
# Yield message start
|
||||||
setup_time = (time.monotonic() - completion_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
|
||||||
)
|
|
||||||
yield StreamStart(messageId=message_id)
|
yield StreamStart(messageId=message_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(
|
|
||||||
"[TIMING] Calling _stream_chat_chunks",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
async for chunk in _stream_chat_chunks(
|
async for chunk in _stream_chat_chunks(
|
||||||
session=session,
|
session=session,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -675,9 +618,6 @@ async def stream_chat_completion(
|
|||||||
total_tokens=chunk.totalTokens,
|
total_tokens=chunk.totalTokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(chunk, StreamHeartbeat):
|
|
||||||
# Pass through heartbeat to keep SSE connection alive
|
|
||||||
yield chunk
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||||
|
|
||||||
@@ -940,21 +880,9 @@ async def _stream_chat_chunks(
|
|||||||
SSE formatted JSON response objects
|
SSE formatted JSON response objects
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import time as time_module
|
|
||||||
|
|
||||||
stream_chunks_start = time_module.perf_counter()
|
|
||||||
model = config.model
|
model = config.model
|
||||||
|
|
||||||
# Build log metadata for structured logging
|
logger.info("Starting pure chat stream")
|
||||||
log_meta = {"component": "ChatService", "session_id": session.session_id}
|
|
||||||
if session.user_id:
|
|
||||||
log_meta["user_id"] = session.user_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
|
|
||||||
f"user={session.user_id}, n_messages={len(session.messages)}",
|
|
||||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = session.to_openai_messages()
|
messages = session.to_openai_messages()
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
@@ -965,18 +893,12 @@ async def _stream_chat_chunks(
|
|||||||
messages = [system_message] + messages
|
messages = [system_message] + messages
|
||||||
|
|
||||||
# Apply context window management
|
# Apply context window management
|
||||||
context_start = time_module.perf_counter()
|
|
||||||
context_result = await _manage_context_window(
|
context_result = await _manage_context_window(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=model,
|
model=model,
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
base_url=config.base_url,
|
base_url=config.base_url,
|
||||||
)
|
)
|
||||||
context_time = (time_module.perf_counter() - context_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
if context_result.error:
|
if context_result.error:
|
||||||
if "System prompt dropped" in context_result.error:
|
if "System prompt dropped" in context_result.error:
|
||||||
@@ -1011,19 +933,9 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
while retry_count <= MAX_RETRIES:
|
while retry_count <= MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
|
|
||||||
retry_info = (
|
|
||||||
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
|
f"Creating OpenAI chat completion stream..."
|
||||||
extra={
|
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"retry_count": retry_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build extra_body for OpenRouter tracing and PostHog analytics
|
# Build extra_body for OpenRouter tracing and PostHog analytics
|
||||||
@@ -1040,7 +952,6 @@ async def _stream_chat_chunks(
|
|||||||
:128
|
:128
|
||||||
] # OpenRouter limit
|
] # OpenRouter limit
|
||||||
|
|
||||||
api_call_start = time_module.perf_counter()
|
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||||
@@ -1050,11 +961,6 @@ async def _stream_chat_chunks(
|
|||||||
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Variables to accumulate tool calls
|
# Variables to accumulate tool calls
|
||||||
tool_calls: list[dict[str, Any]] = []
|
tool_calls: list[dict[str, Any]] = []
|
||||||
@@ -1065,13 +971,10 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
# Track if we've started the text block
|
# Track if we've started the text block
|
||||||
text_started = False
|
text_started = False
|
||||||
first_content_chunk = True
|
|
||||||
chunk_count = 0
|
|
||||||
|
|
||||||
# Process the stream
|
# Process the stream
|
||||||
chunk: ChatCompletionChunk
|
chunk: ChatCompletionChunk
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
chunk_count += 1
|
|
||||||
if chunk.usage:
|
if chunk.usage:
|
||||||
yield StreamUsage(
|
yield StreamUsage(
|
||||||
promptTokens=chunk.usage.prompt_tokens,
|
promptTokens=chunk.usage.prompt_tokens,
|
||||||
@@ -1094,23 +997,6 @@ async def _stream_chat_chunks(
|
|||||||
if not text_started and text_block_id:
|
if not text_started and text_block_id:
|
||||||
yield StreamTextStart(id=text_block_id)
|
yield StreamTextStart(id=text_block_id)
|
||||||
text_started = True
|
text_started = True
|
||||||
# Log timing for first content chunk
|
|
||||||
if first_content_chunk:
|
|
||||||
first_content_chunk = False
|
|
||||||
ttfc = (
|
|
||||||
time_module.perf_counter() - api_call_start
|
|
||||||
) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
|
|
||||||
f"(since API call), n_chunks={chunk_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"time_to_first_chunk_ms": ttfc,
|
|
||||||
"n_chunks": chunk_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Stream the text delta
|
# Stream the text delta
|
||||||
text_response = StreamTextDelta(
|
text_response = StreamTextDelta(
|
||||||
id=text_block_id or "",
|
id=text_block_id or "",
|
||||||
@@ -1167,21 +1053,7 @@ async def _stream_chat_chunks(
|
|||||||
toolName=tool_calls[idx]["function"]["name"],
|
toolName=tool_calls[idx]["function"]["name"],
|
||||||
)
|
)
|
||||||
emitted_start_for_idx.add(idx)
|
emitted_start_for_idx.add(idx)
|
||||||
stream_duration = time_module.perf_counter() - api_call_start
|
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
|
|
||||||
f"duration={stream_duration:.2f}s, "
|
|
||||||
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"stream_duration_ms": stream_duration * 1000,
|
|
||||||
"finish_reason": finish_reason,
|
|
||||||
"n_chunks": chunk_count,
|
|
||||||
"n_tool_calls": len(tool_calls),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Yield all accumulated tool calls after the stream is complete
|
# Yield all accumulated tool calls after the stream is complete
|
||||||
# This ensures all tool call arguments have been fully received
|
# This ensures all tool call arguments have been fully received
|
||||||
@@ -1201,12 +1073,6 @@ async def _stream_chat_chunks(
|
|||||||
# Re-raise to trigger retry logic in the parent function
|
# Re-raise to trigger retry logic in the parent function
|
||||||
raise
|
raise
|
||||||
|
|
||||||
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
|
|
||||||
f"session={session.session_id}, user={session.user_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
|
||||||
)
|
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -104,24 +104,6 @@ async def create_task(
|
|||||||
Returns:
|
Returns:
|
||||||
The created ActiveTask instance (metadata only)
|
The created ActiveTask instance (metadata only)
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
# Build log metadata for structured logging
|
|
||||||
log_meta = {
|
|
||||||
"component": "StreamRegistry",
|
|
||||||
"task_id": task_id,
|
|
||||||
"session_id": session_id,
|
|
||||||
}
|
|
||||||
if user_id:
|
|
||||||
log_meta["user_id"] = user_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
|
|
||||||
task = ActiveTask(
|
task = ActiveTask(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -132,18 +114,10 @@ async def create_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store metadata in Redis
|
# Store metadata in Redis
|
||||||
redis_start = time.perf_counter()
|
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
redis_time = (time.perf_counter() - redis_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
op_key = _get_operation_mapping_key(operation_id)
|
op_key = _get_operation_mapping_key(operation_id)
|
||||||
|
|
||||||
hset_start = time.perf_counter()
|
|
||||||
await redis.hset( # type: ignore[misc]
|
await redis.hset( # type: ignore[misc]
|
||||||
meta_key,
|
meta_key,
|
||||||
mapping={
|
mapping={
|
||||||
@@ -157,22 +131,12 @@ async def create_task(
|
|||||||
"created_at": task.created_at.isoformat(),
|
"created_at": task.created_at.isoformat(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
hset_time = (time.perf_counter() - hset_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] redis.hset took {hset_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
await redis.expire(meta_key, config.stream_ttl)
|
await redis.expire(meta_key, config.stream_ttl)
|
||||||
|
|
||||||
# Create operation_id -> task_id mapping for webhook lookups
|
# Create operation_id -> task_id mapping for webhook lookups
|
||||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||||
|
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
logger.debug(f"Created task {task_id} for session {session_id}")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@@ -192,60 +156,26 @@ async def publish_chunk(
|
|||||||
Returns:
|
Returns:
|
||||||
The Redis Stream message ID
|
The Redis Stream message ID
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
chunk_type = type(chunk).__name__
|
|
||||||
chunk_json = chunk.model_dump_json()
|
chunk_json = chunk.model_dump_json()
|
||||||
message_id = "0-0"
|
message_id = "0-0"
|
||||||
|
|
||||||
# Build log metadata
|
|
||||||
log_meta = {
|
|
||||||
"component": "StreamRegistry",
|
|
||||||
"task_id": task_id,
|
|
||||||
"chunk_type": chunk_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Write to Redis Stream for persistence and real-time delivery
|
# Write to Redis Stream for persistence and real-time delivery
|
||||||
xadd_start = time.perf_counter()
|
|
||||||
raw_id = await redis.xadd(
|
raw_id = await redis.xadd(
|
||||||
stream_key,
|
stream_key,
|
||||||
{"data": chunk_json},
|
{"data": chunk_json},
|
||||||
maxlen=config.stream_max_length,
|
maxlen=config.stream_max_length,
|
||||||
)
|
)
|
||||||
xadd_time = (time.perf_counter() - xadd_start) * 1000
|
|
||||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||||
|
|
||||||
# Set TTL on stream to match task metadata TTL
|
# Set TTL on stream to match task metadata TTL
|
||||||
await redis.expire(stream_key, config.stream_ttl)
|
await redis.expire(stream_key, config.stream_ttl)
|
||||||
|
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
# Only log timing for significant chunks or slow operations
|
|
||||||
if (
|
|
||||||
chunk_type
|
|
||||||
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
|
|
||||||
or total_time > 50
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"xadd_time_ms": xadd_time,
|
|
||||||
"message_id": message_id,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
|
f"Failed to publish chunk for task {task_id}: {e}",
|
||||||
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -270,61 +200,24 @@ async def subscribe_to_task(
|
|||||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||||
or user doesn't have access
|
or user doesn't have access
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
# Build log metadata
|
|
||||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
|
||||||
if user_id:
|
|
||||||
log_meta["user_id"] = user_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
|
|
||||||
)
|
|
||||||
|
|
||||||
redis_start = time.perf_counter()
|
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||||
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not meta:
|
if not meta:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
logger.debug(f"Task {task_id} not found in Redis")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"reason": "task_not_found",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||||
task_status = meta.get("status", "")
|
task_status = meta.get("status", "")
|
||||||
task_user_id = meta.get("user_id", "") or None
|
task_user_id = meta.get("user_id", "") or None
|
||||||
log_meta["session_id"] = meta.get("session_id", "")
|
|
||||||
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
# Validate ownership - if task has an owner, requester must match
|
||||||
if task_user_id:
|
if task_user_id:
|
||||||
if user_id != task_user_id:
|
if user_id != task_user_id:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
|
f"User {user_id} denied access to task {task_id} "
|
||||||
extra={
|
f"owned by {task_user_id}"
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"task_owner": task_user_id,
|
|
||||||
"reason": "access_denied",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -332,19 +225,7 @@ async def subscribe_to_task(
|
|||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Step 1: Replay messages from Redis Stream
|
# Step 1: Replay messages from Redis Stream
|
||||||
xread_start = time.perf_counter()
|
|
||||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": xread_time,
|
|
||||||
"task_status": task_status,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
replayed_count = 0
|
replayed_count = 0
|
||||||
replay_last_id = last_message_id
|
replay_last_id = last_message_id
|
||||||
@@ -363,48 +244,19 @@ async def subscribe_to_task(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to replay message: {e}")
|
logger.warning(f"Failed to replay message: {e}")
|
||||||
|
|
||||||
logger.info(
|
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
||||||
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"n_messages_replayed": replayed_count,
|
|
||||||
"replay_last_id": replay_last_id,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 2: If task is still running, start stream listener for live updates
|
# Step 2: If task is still running, start stream listener for live updates
|
||||||
if task_status == "running":
|
if task_status == "running":
|
||||||
logger.info(
|
|
||||||
"[TIMING] Task still running, starting _stream_listener",
|
|
||||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
|
||||||
)
|
|
||||||
listener_task = asyncio.create_task(
|
listener_task = asyncio.create_task(
|
||||||
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
|
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
||||||
)
|
)
|
||||||
# Track listener task for cleanup on unsubscribe
|
# Track listener task for cleanup on unsubscribe
|
||||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||||
else:
|
else:
|
||||||
# Task is completed/failed - add finish marker
|
# Task is completed/failed - add finish marker
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Task already {task_status}, adding StreamFinish",
|
|
||||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
|
||||||
)
|
|
||||||
await subscriber_queue.put(StreamFinish())
|
await subscriber_queue.put(StreamFinish())
|
||||||
|
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
|
|
||||||
f"n_messages_replayed={replayed_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"n_messages_replayed": replayed_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return subscriber_queue
|
return subscriber_queue
|
||||||
|
|
||||||
|
|
||||||
@@ -412,7 +264,6 @@ async def _stream_listener(
|
|||||||
task_id: str,
|
task_id: str,
|
||||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||||
last_replayed_id: str,
|
last_replayed_id: str,
|
||||||
log_meta: dict | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||||
|
|
||||||
@@ -423,27 +274,10 @@ async def _stream_listener(
|
|||||||
task_id: Task ID to listen for
|
task_id: Task ID to listen for
|
||||||
subscriber_queue: Queue to deliver messages to
|
subscriber_queue: Queue to deliver messages to
|
||||||
last_replayed_id: Last message ID from replay (continue from here)
|
last_replayed_id: Last message ID from replay (continue from here)
|
||||||
log_meta: Structured logging metadata
|
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
# Use provided log_meta or build minimal one
|
|
||||||
if log_meta is None:
|
|
||||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
|
|
||||||
)
|
|
||||||
|
|
||||||
queue_id = id(subscriber_queue)
|
queue_id = id(subscriber_queue)
|
||||||
# Track the last successfully delivered message ID for recovery hints
|
# Track the last successfully delivered message ID for recovery hints
|
||||||
last_delivered_id = last_replayed_id
|
last_delivered_id = last_replayed_id
|
||||||
messages_delivered = 0
|
|
||||||
first_message_time = None
|
|
||||||
xread_count = 0
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
@@ -453,39 +287,9 @@ async def _stream_listener(
|
|||||||
while True:
|
while True:
|
||||||
# Block for up to 30 seconds waiting for new messages
|
# Block for up to 30 seconds waiting for new messages
|
||||||
# This allows periodic checking if task is still running
|
# This allows periodic checking if task is still running
|
||||||
xread_start = time.perf_counter()
|
|
||||||
xread_count += 1
|
|
||||||
messages = await redis.xread(
|
messages = await redis.xread(
|
||||||
{stream_key: current_id}, block=30000, count=100
|
{stream_key: current_id}, block=30000, count=100
|
||||||
)
|
)
|
||||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
|
||||||
|
|
||||||
if messages:
|
|
||||||
msg_count = sum(len(msgs) for _, msgs in messages)
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"xread_count": xread_count,
|
|
||||||
"n_messages": msg_count,
|
|
||||||
"duration_ms": xread_time,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
elif xread_time > 1000:
|
|
||||||
# Only log timeouts (30s blocking)
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"xread_count": xread_count,
|
|
||||||
"duration_ms": xread_time,
|
|
||||||
"reason": "timeout",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not messages:
|
if not messages:
|
||||||
# Timeout - check if task is still running
|
# Timeout - check if task is still running
|
||||||
@@ -522,30 +326,10 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
# Update last delivered ID on successful delivery
|
# Update last delivered ID on successful delivery
|
||||||
last_delivered_id = current_id
|
last_delivered_id = current_id
|
||||||
messages_delivered += 1
|
|
||||||
if first_message_time is None:
|
|
||||||
first_message_time = time.perf_counter()
|
|
||||||
elapsed = (first_message_time - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"chunk_type": type(chunk).__name__,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
f"Subscriber queue full for task {task_id}, "
|
||||||
extra={
|
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"timeout_s": QUEUE_PUT_TIMEOUT,
|
|
||||||
"reason": "queue_full",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
# Send overflow error with recovery info
|
# Send overflow error with recovery info
|
||||||
try:
|
try:
|
||||||
@@ -567,44 +351,15 @@ async def _stream_listener(
|
|||||||
|
|
||||||
# Stop listening on finish
|
# Stop listening on finish
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"messages_delivered": messages_delivered,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(f"Error processing stream message: {e}")
|
||||||
f"Error processing stream message: {e}",
|
|
||||||
extra={"json_fields": {**log_meta, "error": str(e)}},
|
|
||||||
)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
logger.debug(f"Stream listener cancelled for task {task_id}")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"messages_delivered": messages_delivered,
|
|
||||||
"reason": "cancelled",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
raise # Re-raise to propagate cancellation
|
raise # Re-raise to propagate cancellation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
logger.error(f"Stream listener error for task {task_id}: {e}")
|
||||||
logger.error(
|
|
||||||
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
|
|
||||||
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
|
||||||
)
|
|
||||||
# On error, send finish to unblock subscriber
|
# On error, send finish to unblock subscriber
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
@@ -613,24 +368,10 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Could not deliver finish event after error",
|
f"Could not deliver finish event for task {task_id} after error"
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Clean up listener task mapping on exit
|
# Clean up listener task mapping on exit
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
|
|
||||||
f"delivered={messages_delivered}, xread_count={xread_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"messages_delivered": messages_delivered,
|
|
||||||
"xread_count": xread_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
_listener_tasks.pop(queue_id, None)
|
_listener_tasks.pop(queue_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,15 @@ from typing import Any, NotRequired, TypedDict
|
|||||||
|
|
||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
from backend.data.graph import (
|
||||||
|
Graph,
|
||||||
|
Link,
|
||||||
|
Node,
|
||||||
|
create_graph,
|
||||||
|
get_graph,
|
||||||
|
get_graph_all_versions,
|
||||||
|
get_store_listed_graphs,
|
||||||
|
)
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .service import (
|
from .service import (
|
||||||
@@ -20,6 +28,8 @@ from .service import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||||
|
|
||||||
|
|
||||||
class ExecutionSummary(TypedDict):
|
class ExecutionSummary(TypedDict):
|
||||||
"""Summary of a single execution for quality assessment."""
|
"""Summary of a single execution for quality assessment."""
|
||||||
@@ -659,6 +669,45 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _reassign_node_ids(graph: Graph) -> None:
|
||||||
|
"""Reassign all node and link IDs to new UUIDs.
|
||||||
|
|
||||||
|
This is needed when creating a new version to avoid unique constraint violations.
|
||||||
|
"""
|
||||||
|
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||||
|
|
||||||
|
for node in graph.nodes:
|
||||||
|
node.id = id_map[node.id]
|
||||||
|
|
||||||
|
for link in graph.links:
|
||||||
|
link.id = str(uuid.uuid4())
|
||||||
|
if link.source_id in id_map:
|
||||||
|
link.source_id = id_map[link.source_id]
|
||||||
|
if link.sink_id in id_map:
|
||||||
|
link.sink_id = id_map[link.sink_id]
|
||||||
|
|
||||||
|
|
||||||
|
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
|
||||||
|
"""Populate user_id in AgentExecutorBlock nodes.
|
||||||
|
|
||||||
|
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
|
||||||
|
This function fills in the actual user_id so sub-agents run with correct permissions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_json: Agent JSON dict (modified in place)
|
||||||
|
user_id: User ID to set
|
||||||
|
"""
|
||||||
|
for node in agent_json.get("nodes", []):
|
||||||
|
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
|
||||||
|
input_default = node.get("input_default") or {}
|
||||||
|
if not input_default.get("user_id"):
|
||||||
|
input_default["user_id"] = user_id
|
||||||
|
node["input_default"] = input_default
|
||||||
|
logger.debug(
|
||||||
|
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def save_agent_to_library(
|
async def save_agent_to_library(
|
||||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||||
) -> tuple[Graph, Any]:
|
) -> tuple[Graph, Any]:
|
||||||
@@ -672,10 +721,35 @@ async def save_agent_to_library(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (created Graph, LibraryAgent)
|
Tuple of (created Graph, LibraryAgent)
|
||||||
"""
|
"""
|
||||||
|
# Populate user_id in AgentExecutorBlock nodes before conversion
|
||||||
|
_populate_agent_executor_user_ids(agent_json, user_id)
|
||||||
|
|
||||||
graph = json_to_graph(agent_json)
|
graph = json_to_graph(agent_json)
|
||||||
|
|
||||||
if is_update:
|
if is_update:
|
||||||
return await library_db.update_graph_in_library(graph, user_id)
|
if graph.id:
|
||||||
return await library_db.create_graph_in_library(graph, user_id)
|
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||||
|
if existing_versions:
|
||||||
|
latest_version = max(v.version for v in existing_versions)
|
||||||
|
graph.version = latest_version + 1
|
||||||
|
_reassign_node_ids(graph)
|
||||||
|
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||||
|
else:
|
||||||
|
graph.id = str(uuid.uuid4())
|
||||||
|
graph.version = 1
|
||||||
|
_reassign_node_ids(graph)
|
||||||
|
logger.info(f"Creating new agent with ID {graph.id}")
|
||||||
|
|
||||||
|
created_graph = await create_graph(graph, user_id)
|
||||||
|
|
||||||
|
library_agents = await library_db.create_library_agent(
|
||||||
|
graph=created_graph,
|
||||||
|
user_id=user_id,
|
||||||
|
sensitive_action_safe_mode=True,
|
||||||
|
create_library_agents_for_sub_graphs=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return created_graph, library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -206,9 +206,9 @@ async def search_agents(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
no_results_msg = (
|
no_results_msg = (
|
||||||
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
|
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
||||||
if source == "marketplace"
|
if source == "marketplace"
|
||||||
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
|
else f"No agents matching '{query}' found in your library."
|
||||||
)
|
)
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||||
@@ -224,10 +224,10 @@ async def search_agents(
|
|||||||
message = (
|
message = (
|
||||||
"Now you have found some options for the user to choose from. "
|
"Now you have found some options for the user to choose from. "
|
||||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||||
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
|
"Please ask the user if they would like to use any of these agents."
|
||||||
if source == "marketplace"
|
if source == "marketplace"
|
||||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
|
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
||||||
)
|
)
|
||||||
|
|
||||||
return AgentsFoundResponse(
|
return AgentsFoundResponse(
|
||||||
|
|||||||
@@ -13,32 +13,10 @@ from backend.api.features.chat.tools.models import (
|
|||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||||
from backend.data.block import BlockType, get_block
|
from backend.data.block import get_block
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_TARGET_RESULTS = 10
|
|
||||||
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
|
|
||||||
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
|
|
||||||
_OVERFETCH_PAGE_SIZE = 40
|
|
||||||
|
|
||||||
# Block types that only work within graphs and cannot run standalone in CoPilot.
|
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES = {
|
|
||||||
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
|
|
||||||
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
|
|
||||||
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
|
|
||||||
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
|
|
||||||
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
|
||||||
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
|
||||||
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
|
||||||
}
|
|
||||||
|
|
||||||
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
|
||||||
COPILOT_EXCLUDED_BLOCK_IDS = {
|
|
||||||
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
|
|
||||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class FindBlockTool(BaseTool):
|
class FindBlockTool(BaseTool):
|
||||||
"""Tool for searching available blocks."""
|
"""Tool for searching available blocks."""
|
||||||
@@ -110,7 +88,7 @@ class FindBlockTool(BaseTool):
|
|||||||
query=query,
|
query=query,
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=1,
|
page=1,
|
||||||
page_size=_OVERFETCH_PAGE_SIZE,
|
page_size=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
@@ -130,90 +108,60 @@ class FindBlockTool(BaseTool):
|
|||||||
block = get_block(block_id)
|
block = get_block(block_id)
|
||||||
|
|
||||||
# Skip disabled blocks
|
# Skip disabled blocks
|
||||||
if not block or block.disabled:
|
if block and not block.disabled:
|
||||||
continue
|
# Get input/output schemas
|
||||||
|
input_schema = {}
|
||||||
|
output_schema = {}
|
||||||
|
try:
|
||||||
|
input_schema = block.input_schema.jsonschema()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
output_schema = block.output_schema.jsonschema()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Skip blocks excluded from CoPilot (graph-only blocks)
|
# Get categories from block instance
|
||||||
if (
|
categories = []
|
||||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
if hasattr(block, "categories") and block.categories:
|
||||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
categories = [cat.value for cat in block.categories]
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get input/output schemas
|
# Extract required inputs for easier use
|
||||||
input_schema = {}
|
required_inputs: list[BlockInputFieldInfo] = []
|
||||||
output_schema = {}
|
if input_schema:
|
||||||
try:
|
properties = input_schema.get("properties", {})
|
||||||
input_schema = block.input_schema.jsonschema()
|
required_fields = set(input_schema.get("required", []))
|
||||||
except Exception as e:
|
# Get credential field names to exclude from required inputs
|
||||||
logger.debug(
|
credentials_fields = set(
|
||||||
"Failed to generate input schema for block %s: %s",
|
block.input_schema.get_credentials_fields().keys()
|
||||||
block_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
output_schema = block.output_schema.jsonschema()
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(
|
|
||||||
"Failed to generate output schema for block %s: %s",
|
|
||||||
block_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get categories from block instance
|
|
||||||
categories = []
|
|
||||||
if hasattr(block, "categories") and block.categories:
|
|
||||||
categories = [cat.value for cat in block.categories]
|
|
||||||
|
|
||||||
# Extract required inputs for easier use
|
|
||||||
required_inputs: list[BlockInputFieldInfo] = []
|
|
||||||
if input_schema:
|
|
||||||
properties = input_schema.get("properties", {})
|
|
||||||
required_fields = set(input_schema.get("required", []))
|
|
||||||
# Get credential field names to exclude from required inputs
|
|
||||||
credentials_fields = set(
|
|
||||||
block.input_schema.get_credentials_fields().keys()
|
|
||||||
)
|
|
||||||
|
|
||||||
for field_name, field_schema in properties.items():
|
|
||||||
# Skip credential fields - they're handled separately
|
|
||||||
if field_name in credentials_fields:
|
|
||||||
continue
|
|
||||||
|
|
||||||
required_inputs.append(
|
|
||||||
BlockInputFieldInfo(
|
|
||||||
name=field_name,
|
|
||||||
type=field_schema.get("type", "string"),
|
|
||||||
description=field_schema.get("description", ""),
|
|
||||||
required=field_name in required_fields,
|
|
||||||
default=field_schema.get("default"),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
blocks.append(
|
for field_name, field_schema in properties.items():
|
||||||
BlockInfoSummary(
|
# Skip credential fields - they're handled separately
|
||||||
id=block_id,
|
if field_name in credentials_fields:
|
||||||
name=block.name,
|
continue
|
||||||
description=block.description or "",
|
|
||||||
categories=categories,
|
required_inputs.append(
|
||||||
input_schema=input_schema,
|
BlockInputFieldInfo(
|
||||||
output_schema=output_schema,
|
name=field_name,
|
||||||
required_inputs=required_inputs,
|
type=field_schema.get("type", "string"),
|
||||||
|
description=field_schema.get("description", ""),
|
||||||
|
required=field_name in required_fields,
|
||||||
|
default=field_schema.get("default"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
blocks.append(
|
||||||
|
BlockInfoSummary(
|
||||||
|
id=block_id,
|
||||||
|
name=block.name,
|
||||||
|
description=block.description or "",
|
||||||
|
categories=categories,
|
||||||
|
input_schema=input_schema,
|
||||||
|
output_schema=output_schema,
|
||||||
|
required_inputs=required_inputs,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if len(blocks) >= _TARGET_RESULTS:
|
|
||||||
break
|
|
||||||
|
|
||||||
if blocks and len(blocks) < _TARGET_RESULTS:
|
|
||||||
logger.debug(
|
|
||||||
"find_block returned %d/%d results for query '%s' "
|
|
||||||
"(filtered %d excluded/disabled blocks)",
|
|
||||||
len(blocks),
|
|
||||||
_TARGET_RESULTS,
|
|
||||||
query,
|
|
||||||
len(results) - len(blocks),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not blocks:
|
if not blocks:
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
|
|||||||
@@ -1,139 +0,0 @@
|
|||||||
"""Tests for block filtering in FindBlockTool."""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.api.features.chat.tools.find_block import (
|
|
||||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
|
||||||
FindBlockTool,
|
|
||||||
)
|
|
||||||
from backend.api.features.chat.tools.models import BlockListResponse
|
|
||||||
from backend.data.block import BlockType
|
|
||||||
|
|
||||||
from ._test_data import make_session
|
|
||||||
|
|
||||||
_TEST_USER_ID = "test-user-find-block"
|
|
||||||
|
|
||||||
|
|
||||||
def make_mock_block(
|
|
||||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
|
||||||
):
|
|
||||||
"""Create a mock block for testing."""
|
|
||||||
mock = MagicMock()
|
|
||||||
mock.id = block_id
|
|
||||||
mock.name = name
|
|
||||||
mock.description = f"{name} description"
|
|
||||||
mock.block_type = block_type
|
|
||||||
mock.disabled = disabled
|
|
||||||
mock.input_schema = MagicMock()
|
|
||||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
|
||||||
mock.input_schema.get_credentials_fields.return_value = {}
|
|
||||||
mock.output_schema = MagicMock()
|
|
||||||
mock.output_schema.jsonschema.return_value = {}
|
|
||||||
mock.categories = []
|
|
||||||
return mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestFindBlockFiltering:
|
|
||||||
"""Tests for block filtering in FindBlockTool."""
|
|
||||||
|
|
||||||
def test_excluded_block_types_contains_expected_types(self):
|
|
||||||
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
|
||||||
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
|
|
||||||
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
|
||||||
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
|
||||||
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_type_filtered_from_results(self):
|
|
||||||
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
|
||||||
search_results = [
|
|
||||||
{"content_id": "input-block-id", "score": 0.9},
|
|
||||||
{"content_id": "standard-block-id", "score": 0.8},
|
|
||||||
]
|
|
||||||
|
|
||||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
|
||||||
standard_block = make_mock_block(
|
|
||||||
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_get_block(block_id):
|
|
||||||
return {
|
|
||||||
"input-block-id": input_block,
|
|
||||||
"standard-block-id": standard_block,
|
|
||||||
}.get(block_id)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=(search_results, 2),
|
|
||||||
):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.get_block",
|
|
||||||
side_effect=mock_get_block,
|
|
||||||
):
|
|
||||||
tool = FindBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID, session=session, query="test"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should only return the standard block, not the INPUT block
|
|
||||||
assert isinstance(response, BlockListResponse)
|
|
||||||
assert len(response.blocks) == 1
|
|
||||||
assert response.blocks[0].id == "standard-block-id"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_id_filtered_from_results(self):
|
|
||||||
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
|
||||||
search_results = [
|
|
||||||
{"content_id": smart_decision_id, "score": 0.9},
|
|
||||||
{"content_id": "normal-block-id", "score": 0.8},
|
|
||||||
]
|
|
||||||
|
|
||||||
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
|
||||||
smart_block = make_mock_block(
|
|
||||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
normal_block = make_mock_block(
|
|
||||||
"normal-block-id", "Normal Block", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_get_block(block_id):
|
|
||||||
return {
|
|
||||||
smart_decision_id: smart_block,
|
|
||||||
"normal-block-id": normal_block,
|
|
||||||
}.get(block_id)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=(search_results, 2),
|
|
||||||
):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.get_block",
|
|
||||||
side_effect=mock_get_block,
|
|
||||||
):
|
|
||||||
tool = FindBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID, session=session, query="decision"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should only return normal block, not SmartDecisionMakerBlock
|
|
||||||
assert isinstance(response, BlockListResponse)
|
|
||||||
assert len(response.blocks) == 1
|
|
||||||
assert response.blocks[0].id == "normal-block-id"
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
"""Shared helpers for chat tools."""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def get_inputs_from_schema(
|
|
||||||
input_schema: dict[str, Any],
|
|
||||||
exclude_fields: set[str] | None = None,
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Extract input field info from JSON schema."""
|
|
||||||
if not isinstance(input_schema, dict):
|
|
||||||
return []
|
|
||||||
|
|
||||||
exclude = exclude_fields or set()
|
|
||||||
properties = input_schema.get("properties", {})
|
|
||||||
required = set(input_schema.get("required", []))
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"name": name,
|
|
||||||
"title": schema.get("title", name),
|
|
||||||
"type": schema.get("type", "string"),
|
|
||||||
"description": schema.get("description", ""),
|
|
||||||
"required": name in required,
|
|
||||||
"default": schema.get("default"),
|
|
||||||
}
|
|
||||||
for name, schema in properties.items()
|
|
||||||
if name not in exclude
|
|
||||||
]
|
|
||||||
@@ -24,7 +24,6 @@ from backend.util.timezone_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .helpers import get_inputs_from_schema
|
|
||||||
from .models import (
|
from .models import (
|
||||||
AgentDetails,
|
AgentDetails,
|
||||||
AgentDetailsResponse,
|
AgentDetailsResponse,
|
||||||
@@ -262,7 +261,7 @@ class RunAgentTool(BaseTool):
|
|||||||
),
|
),
|
||||||
requirements={
|
requirements={
|
||||||
"credentials": requirements_creds_list,
|
"credentials": requirements_creds_list,
|
||||||
"inputs": get_inputs_from_schema(graph.input_schema),
|
"inputs": self._get_inputs_list(graph.input_schema),
|
||||||
"execution_modes": self._get_execution_modes(graph),
|
"execution_modes": self._get_execution_modes(graph),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@@ -370,6 +369,22 @@ class RunAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
|
"""Extract inputs list from schema."""
|
||||||
|
inputs_list = []
|
||||||
|
if isinstance(input_schema, dict) and "properties" in input_schema:
|
||||||
|
for field_name, field_schema in input_schema["properties"].items():
|
||||||
|
inputs_list.append(
|
||||||
|
{
|
||||||
|
"name": field_name,
|
||||||
|
"title": field_schema.get("title", field_name),
|
||||||
|
"type": field_schema.get("type", "string"),
|
||||||
|
"description": field_schema.get("description", ""),
|
||||||
|
"required": field_name in input_schema.get("required", []),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return inputs_list
|
||||||
|
|
||||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||||
"""Get available execution modes for the graph."""
|
"""Get available execution modes for the graph."""
|
||||||
trigger_info = graph.trigger_setup_info
|
trigger_info = graph.trigger_setup_info
|
||||||
@@ -383,7 +398,7 @@ class RunAgentTool(BaseTool):
|
|||||||
suffix: str,
|
suffix: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a message describing available inputs for an agent."""
|
"""Build a message describing available inputs for an agent."""
|
||||||
inputs_list = get_inputs_from_schema(graph.input_schema)
|
inputs_list = self._get_inputs_list(graph.input_schema)
|
||||||
required_names = [i["name"] for i in inputs_list if i["required"]]
|
required_names = [i["name"] for i in inputs_list if i["required"]]
|
||||||
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
||||||
|
|
||||||
|
|||||||
@@ -8,19 +8,14 @@ from typing import Any
|
|||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.api.features.chat.tools.find_block import (
|
from backend.data.block import get_block
|
||||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
|
||||||
)
|
|
||||||
from backend.data.block import AnyBlockSchema, get_block
|
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .helpers import get_inputs_from_schema
|
|
||||||
from .models import (
|
from .models import (
|
||||||
BlockOutputResponse,
|
BlockOutputResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -29,10 +24,7 @@ from .models import (
|
|||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
UserReadiness,
|
UserReadiness,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import build_missing_credentials_from_field_info
|
||||||
build_missing_credentials_from_field_info,
|
|
||||||
match_credentials_to_requirements,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -81,6 +73,91 @@ class RunBlockTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def _check_block_credentials(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
block: Any,
|
||||||
|
input_data: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
|
"""
|
||||||
|
Check if user has required credentials for a block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
block: Block to check credentials for
|
||||||
|
input_data: Input data for the block (used to determine provider via discriminator)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[matched_credentials, missing_credentials]
|
||||||
|
"""
|
||||||
|
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||||
|
missing_credentials: list[CredentialsMetaInput] = []
|
||||||
|
input_data = input_data or {}
|
||||||
|
|
||||||
|
# Get credential field info from block's input schema
|
||||||
|
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||||
|
|
||||||
|
if not credentials_fields_info:
|
||||||
|
return matched_credentials, missing_credentials
|
||||||
|
|
||||||
|
# Get user's available credentials
|
||||||
|
creds_manager = IntegrationCredentialsManager()
|
||||||
|
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||||
|
|
||||||
|
for field_name, field_info in credentials_fields_info.items():
|
||||||
|
effective_field_info = field_info
|
||||||
|
if field_info.discriminator and field_info.discriminator_mapping:
|
||||||
|
# Get discriminator from input, falling back to schema default
|
||||||
|
discriminator_value = input_data.get(field_info.discriminator)
|
||||||
|
if discriminator_value is None:
|
||||||
|
field = block.input_schema.model_fields.get(
|
||||||
|
field_info.discriminator
|
||||||
|
)
|
||||||
|
if field and field.default is not PydanticUndefined:
|
||||||
|
discriminator_value = field.default
|
||||||
|
|
||||||
|
if (
|
||||||
|
discriminator_value
|
||||||
|
and discriminator_value in field_info.discriminator_mapping
|
||||||
|
):
|
||||||
|
effective_field_info = field_info.discriminate(discriminator_value)
|
||||||
|
logger.debug(
|
||||||
|
f"Discriminated provider for {field_name}: "
|
||||||
|
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
matching_cred = next(
|
||||||
|
(
|
||||||
|
cred
|
||||||
|
for cred in available_creds
|
||||||
|
if cred.provider in effective_field_info.provider
|
||||||
|
and cred.type in effective_field_info.supported_types
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if matching_cred:
|
||||||
|
matched_credentials[field_name] = CredentialsMetaInput(
|
||||||
|
id=matching_cred.id,
|
||||||
|
provider=matching_cred.provider, # type: ignore
|
||||||
|
type=matching_cred.type,
|
||||||
|
title=matching_cred.title,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create a placeholder for the missing credential
|
||||||
|
provider = next(iter(effective_field_info.provider), "unknown")
|
||||||
|
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
||||||
|
missing_credentials.append(
|
||||||
|
CredentialsMetaInput(
|
||||||
|
id=field_name,
|
||||||
|
provider=provider, # type: ignore
|
||||||
|
type=cred_type, # type: ignore
|
||||||
|
title=field_name.replace("_", " ").title(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return matched_credentials, missing_credentials
|
||||||
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -135,24 +212,11 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if block is excluded from CoPilot (graph-only blocks)
|
|
||||||
if (
|
|
||||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
|
||||||
):
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
|
||||||
"This block is designed for use within graphs only."
|
|
||||||
),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||||
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
creds_manager = IntegrationCredentialsManager()
|
||||||
matched_credentials, missing_credentials = (
|
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||||
await self._resolve_block_credentials(user_id, block, input_data)
|
user_id, block, input_data
|
||||||
)
|
)
|
||||||
|
|
||||||
if missing_credentials:
|
if missing_credentials:
|
||||||
@@ -281,75 +345,29 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _resolve_block_credentials(
|
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
block: AnyBlockSchema,
|
|
||||||
input_data: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Resolve credentials for a block by matching user's available credentials.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
block: Block to resolve credentials for
|
|
||||||
input_data: Input data for the block (used to determine provider via discriminator)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple of (matched_credentials, missing_credentials) - matched credentials
|
|
||||||
are used for block execution, missing ones indicate setup requirements.
|
|
||||||
"""
|
|
||||||
input_data = input_data or {}
|
|
||||||
requirements = self._resolve_discriminated_credentials(block, input_data)
|
|
||||||
|
|
||||||
if not requirements:
|
|
||||||
return {}, []
|
|
||||||
|
|
||||||
return await match_credentials_to_requirements(user_id, requirements)
|
|
||||||
|
|
||||||
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
|
||||||
"""Extract non-credential inputs from block schema."""
|
"""Extract non-credential inputs from block schema."""
|
||||||
|
inputs_list = []
|
||||||
schema = block.input_schema.jsonschema()
|
schema = block.input_schema.jsonschema()
|
||||||
|
properties = schema.get("properties", {})
|
||||||
|
required_fields = set(schema.get("required", []))
|
||||||
|
|
||||||
|
# Get credential field names to exclude
|
||||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||||
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
|
||||||
|
|
||||||
def _resolve_discriminated_credentials(
|
for field_name, field_schema in properties.items():
|
||||||
self,
|
# Skip credential fields
|
||||||
block: AnyBlockSchema,
|
if field_name in credentials_fields:
|
||||||
input_data: dict[str, Any],
|
continue
|
||||||
) -> dict[str, CredentialsFieldInfo]:
|
|
||||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
|
||||||
if not credentials_fields_info:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
inputs_list.append(
|
||||||
|
{
|
||||||
|
"name": field_name,
|
||||||
|
"title": field_schema.get("title", field_name),
|
||||||
|
"type": field_schema.get("type", "string"),
|
||||||
|
"description": field_schema.get("description", ""),
|
||||||
|
"required": field_name in required_fields,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
return inputs_list
|
||||||
effective_field_info = field_info
|
|
||||||
|
|
||||||
if field_info.discriminator and field_info.discriminator_mapping:
|
|
||||||
discriminator_value = input_data.get(field_info.discriminator)
|
|
||||||
if discriminator_value is None:
|
|
||||||
field = block.input_schema.model_fields.get(
|
|
||||||
field_info.discriminator
|
|
||||||
)
|
|
||||||
if field and field.default is not PydanticUndefined:
|
|
||||||
discriminator_value = field.default
|
|
||||||
|
|
||||||
if (
|
|
||||||
discriminator_value
|
|
||||||
and discriminator_value in field_info.discriminator_mapping
|
|
||||||
):
|
|
||||||
effective_field_info = field_info.discriminate(discriminator_value)
|
|
||||||
# For host-scoped credentials, add the discriminator value
|
|
||||||
# (e.g., URL) so _credential_is_for_host can match it
|
|
||||||
effective_field_info.discriminator_values.add(discriminator_value)
|
|
||||||
logger.debug(
|
|
||||||
f"Discriminated provider for {field_name}: "
|
|
||||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved[field_name] = effective_field_info
|
|
||||||
|
|
||||||
return resolved
|
|
||||||
|
|||||||
@@ -1,106 +0,0 @@
|
|||||||
"""Tests for block execution guards in RunBlockTool."""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.api.features.chat.tools.models import ErrorResponse
|
|
||||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
|
||||||
from backend.data.block import BlockType
|
|
||||||
|
|
||||||
from ._test_data import make_session
|
|
||||||
|
|
||||||
_TEST_USER_ID = "test-user-run-block"
|
|
||||||
|
|
||||||
|
|
||||||
def make_mock_block(
|
|
||||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
|
||||||
):
|
|
||||||
"""Create a mock block for testing."""
|
|
||||||
mock = MagicMock()
|
|
||||||
mock.id = block_id
|
|
||||||
mock.name = name
|
|
||||||
mock.block_type = block_type
|
|
||||||
mock.disabled = disabled
|
|
||||||
mock.input_schema = MagicMock()
|
|
||||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
|
||||||
mock.input_schema.get_credentials_fields_info.return_value = []
|
|
||||||
return mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunBlockFiltering:
|
|
||||||
"""Tests for block execution guards in RunBlockTool."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_type_returns_error(self):
|
|
||||||
"""Attempting to execute a block with excluded BlockType returns error."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=input_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id="input-block-id",
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, ErrorResponse)
|
|
||||||
assert "cannot be run directly in CoPilot" in response.message
|
|
||||||
assert "designed for use within graphs only" in response.message
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_id_returns_error(self):
|
|
||||||
"""Attempting to execute SmartDecisionMakerBlock returns error."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
|
||||||
smart_block = make_mock_block(
|
|
||||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=smart_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id=smart_decision_id,
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, ErrorResponse)
|
|
||||||
assert "cannot be run directly in CoPilot" in response.message
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_non_excluded_block_passes_guard(self):
|
|
||||||
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
standard_block = make_mock_block(
|
|
||||||
"standard-id", "HTTP Request", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=standard_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id="standard-id",
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should NOT be an ErrorResponse about CoPilot exclusion
|
|
||||||
# (may be other errors like missing credentials, but not the exclusion guard)
|
|
||||||
if isinstance(response, ErrorResponse):
|
|
||||||
assert "cannot be run directly in CoPilot" not in response.message
|
|
||||||
@@ -6,16 +6,15 @@ from typing import Any
|
|||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.library import model as library_model
|
from backend.api.features.library import model as library_model
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
|
from backend.data import graph as graph_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
Credentials,
|
|
||||||
CredentialsFieldInfo,
|
CredentialsFieldInfo,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
HostScopedCredentials,
|
HostScopedCredentials,
|
||||||
OAuth2Credentials,
|
OAuth2Credentials,
|
||||||
)
|
)
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -45,8 +44,14 @@ async def fetch_graph_from_store_slug(
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Get the graph from store listing version
|
# Get the graph from store listing version
|
||||||
graph = await store_db.get_available_graph(
|
graph_meta = await store_db.get_available_graph(
|
||||||
store_agent.store_listing_version_id, hide_nodes=False
|
store_agent.store_listing_version_id
|
||||||
|
)
|
||||||
|
graph = await graph_db.get_graph(
|
||||||
|
graph_id=graph_meta.id,
|
||||||
|
version=graph_meta.version,
|
||||||
|
user_id=None, # Public access
|
||||||
|
include_subgraphs=True,
|
||||||
)
|
)
|
||||||
return graph, store_agent
|
return graph, store_agent
|
||||||
|
|
||||||
@@ -123,7 +128,7 @@ def build_missing_credentials_from_graph(
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
field_key: _serialize_missing_credential(field_key, field_info)
|
field_key: _serialize_missing_credential(field_key, field_info)
|
||||||
for field_key, (field_info, _, _) in aggregated_fields.items()
|
for field_key, (field_info, _node_fields) in aggregated_fields.items()
|
||||||
if field_key not in matched_keys
|
if field_key not in matched_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,99 +230,6 @@ async def get_or_create_library_agent(
|
|||||||
return library_agents[0]
|
return library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
async def match_credentials_to_requirements(
|
|
||||||
user_id: str,
|
|
||||||
requirements: dict[str, CredentialsFieldInfo],
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Match user's credentials against a dictionary of credential requirements.
|
|
||||||
|
|
||||||
This is the core matching logic shared by both graph and block credential matching.
|
|
||||||
"""
|
|
||||||
matched: dict[str, CredentialsMetaInput] = {}
|
|
||||||
missing: list[CredentialsMetaInput] = []
|
|
||||||
|
|
||||||
if not requirements:
|
|
||||||
return matched, missing
|
|
||||||
|
|
||||||
available_creds = await get_user_credentials(user_id)
|
|
||||||
|
|
||||||
for field_name, field_info in requirements.items():
|
|
||||||
matching_cred = find_matching_credential(available_creds, field_info)
|
|
||||||
|
|
||||||
if matching_cred:
|
|
||||||
try:
|
|
||||||
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
|
||||||
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
|
||||||
f"credential_id={matching_cred.id}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
provider = next(iter(field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
|
||||||
missing.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=f"{field_name} (validation failed: {e})",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
provider = next(iter(field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
|
||||||
missing.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=field_name.replace("_", " ").title(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matched, missing
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_credentials(user_id: str) -> list[Credentials]:
|
|
||||||
"""Get all available credentials for a user."""
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
return await creds_manager.store.get_all_creds(user_id)
|
|
||||||
|
|
||||||
|
|
||||||
def find_matching_credential(
|
|
||||||
available_creds: list[Credentials],
|
|
||||||
field_info: CredentialsFieldInfo,
|
|
||||||
) -> Credentials | None:
|
|
||||||
"""Find a credential that matches the required provider, type, scopes, and host."""
|
|
||||||
for cred in available_creds:
|
|
||||||
if cred.provider not in field_info.provider:
|
|
||||||
continue
|
|
||||||
if cred.type not in field_info.supported_types:
|
|
||||||
continue
|
|
||||||
if cred.type == "oauth2" and not _credential_has_required_scopes(
|
|
||||||
cred, field_info
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
|
|
||||||
continue
|
|
||||||
return cred
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_credential_meta_from_match(
|
|
||||||
matching_cred: Credentials,
|
|
||||||
) -> CredentialsMetaInput:
|
|
||||||
"""Create a CredentialsMetaInput from a matched credential."""
|
|
||||||
return CredentialsMetaInput(
|
|
||||||
id=matching_cred.id,
|
|
||||||
provider=matching_cred.provider, # type: ignore
|
|
||||||
type=matching_cred.type,
|
|
||||||
title=matching_cred.title,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def match_user_credentials_to_graph(
|
async def match_user_credentials_to_graph(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
graph: GraphModel,
|
graph: GraphModel,
|
||||||
@@ -357,10 +269,9 @@ async def match_user_credentials_to_graph(
|
|||||||
# provider is in the set of acceptable providers.
|
# provider is in the set of acceptable providers.
|
||||||
for credential_field_name, (
|
for credential_field_name, (
|
||||||
credential_requirements,
|
credential_requirements,
|
||||||
_,
|
_node_fields,
|
||||||
_,
|
|
||||||
) in aggregated_creds.items():
|
) in aggregated_creds.items():
|
||||||
# Find first matching credential by provider, type, scopes, and host/URL
|
# Find first matching credential by provider, type, and scopes
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
(
|
(
|
||||||
cred
|
cred
|
||||||
@@ -375,10 +286,6 @@ async def match_user_credentials_to_graph(
|
|||||||
cred.type != "host_scoped"
|
cred.type != "host_scoped"
|
||||||
or _credential_is_for_host(cred, credential_requirements)
|
or _credential_is_for_host(cred, credential_requirements)
|
||||||
)
|
)
|
||||||
and (
|
|
||||||
cred.provider != ProviderName.MCP
|
|
||||||
or _credential_is_for_mcp_server(cred, credential_requirements)
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -430,6 +337,8 @@ def _credential_has_required_scopes(
|
|||||||
# If no scopes are required, any credential matches
|
# If no scopes are required, any credential matches
|
||||||
if not requirements.required_scopes:
|
if not requirements.required_scopes:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# Check that credential scopes are a superset of required scopes
|
||||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||||
|
|
||||||
|
|
||||||
@@ -449,22 +358,6 @@ def _credential_is_for_host(
|
|||||||
return credential.matches_url(list(requirements.discriminator_values)[0])
|
return credential.matches_url(list(requirements.discriminator_values)[0])
|
||||||
|
|
||||||
|
|
||||||
def _credential_is_for_mcp_server(
|
|
||||||
credential: Credentials,
|
|
||||||
requirements: CredentialsFieldInfo,
|
|
||||||
) -> bool:
|
|
||||||
"""Check if an MCP OAuth credential matches the required server URL."""
|
|
||||||
if not requirements.discriminator_values:
|
|
||||||
return True
|
|
||||||
|
|
||||||
server_url = (
|
|
||||||
credential.metadata.get("mcp_server_url")
|
|
||||||
if isinstance(credential, OAuth2Credentials)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
return server_url in requirements.discriminator_values if server_url else False
|
|
||||||
|
|
||||||
|
|
||||||
async def check_user_has_required_credentials(
|
async def check_user_has_required_credentials(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
required_credentials: list[CredentialsMetaInput],
|
required_credentials: list[CredentialsMetaInput],
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
|
from typing import TYPE_CHECKING, Annotated, List, Literal
|
||||||
|
|
||||||
from autogpt_libs.auth import get_user_id
|
from autogpt_libs.auth import get_user_id
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
@@ -14,7 +14,7 @@ from fastapi import (
|
|||||||
Security,
|
Security,
|
||||||
status,
|
status,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, Field, SecretStr, model_validator
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||||
|
|
||||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||||
@@ -102,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
|
|||||||
scopes: list[str] | None
|
scopes: list[str] | None
|
||||||
username: str | None
|
username: str | None
|
||||||
host: str | None = Field(
|
host: str | None = Field(
|
||||||
default=None,
|
default=None, description="Host pattern for host-scoped credentials"
|
||||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _normalize_provider(cls, data: Any) -> Any:
|
|
||||||
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
prov = data.get("provider", "")
|
|
||||||
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
|
||||||
member = prov.removeprefix("ProviderName.")
|
|
||||||
try:
|
|
||||||
data = {**data, "provider": ProviderName[member].value}
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_host(cred: Credentials) -> str | None:
|
|
||||||
"""Extract host from credential: HostScoped host or MCP server URL."""
|
|
||||||
if isinstance(cred, HostScopedCredentials):
|
|
||||||
return cred.host
|
|
||||||
if isinstance(cred, OAuth2Credentials) and cred.provider in (
|
|
||||||
ProviderName.MCP,
|
|
||||||
ProviderName.MCP.value,
|
|
||||||
"ProviderName.MCP",
|
|
||||||
):
|
|
||||||
return (cred.metadata or {}).get("mcp_server_url")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||||
async def callback(
|
async def callback(
|
||||||
@@ -207,7 +179,9 @@ async def callback(
|
|||||||
title=credentials.title,
|
title=credentials.title,
|
||||||
scopes=credentials.scopes,
|
scopes=credentials.scopes,
|
||||||
username=credentials.username,
|
username=credentials.username,
|
||||||
host=(CredentialsMetaResponse.get_host(credentials)),
|
host=(
|
||||||
|
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -225,7 +199,7 @@ async def list_credentials(
|
|||||||
title=cred.title,
|
title=cred.title,
|
||||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||||
host=CredentialsMetaResponse.get_host(cred),
|
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||||
)
|
)
|
||||||
for cred in credentials
|
for cred in credentials
|
||||||
]
|
]
|
||||||
@@ -248,7 +222,7 @@ async def list_credentials_by_provider(
|
|||||||
title=cred.title,
|
title=cred.title,
|
||||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||||
host=CredentialsMetaResponse.get_host(cred),
|
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||||
)
|
)
|
||||||
for cred in credentials
|
for cred in credentials
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -19,10 +19,7 @@ from backend.data.graph import GraphSettings
|
|||||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||||
on_graph_activate,
|
|
||||||
on_graph_deactivate,
|
|
||||||
)
|
|
||||||
from backend.util.clients import get_scheduler_client
|
from backend.util.clients import get_scheduler_client
|
||||||
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
@@ -374,7 +371,7 @@ async def get_library_agent_by_graph_id(
|
|||||||
|
|
||||||
|
|
||||||
async def add_generated_agent_image(
|
async def add_generated_agent_image(
|
||||||
graph: graph_db.GraphBaseMeta,
|
graph: graph_db.BaseGraph,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
) -> Optional[prisma.models.LibraryAgent]:
|
) -> Optional[prisma.models.LibraryAgent]:
|
||||||
@@ -540,92 +537,6 @@ async def update_agent_version_in_library(
|
|||||||
return library_model.LibraryAgent.from_db(lib)
|
return library_model.LibraryAgent.from_db(lib)
|
||||||
|
|
||||||
|
|
||||||
async def create_graph_in_library(
|
|
||||||
graph: graph_db.Graph,
|
|
||||||
user_id: str,
|
|
||||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
|
||||||
"""Create a new graph and add it to the user's library."""
|
|
||||||
graph.version = 1
|
|
||||||
graph_model = graph_db.make_graph_model(graph, user_id)
|
|
||||||
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
|
||||||
|
|
||||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
|
||||||
|
|
||||||
library_agents = await create_library_agent(
|
|
||||||
graph=created_graph,
|
|
||||||
user_id=user_id,
|
|
||||||
sensitive_action_safe_mode=True,
|
|
||||||
create_library_agents_for_sub_graphs=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if created_graph.is_active:
|
|
||||||
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
|
||||||
|
|
||||||
return created_graph, library_agents[0]
|
|
||||||
|
|
||||||
|
|
||||||
async def update_graph_in_library(
|
|
||||||
graph: graph_db.Graph,
|
|
||||||
user_id: str,
|
|
||||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
|
||||||
"""Create a new version of an existing graph and update the library entry."""
|
|
||||||
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
|
|
||||||
current_active_version = (
|
|
||||||
next((v for v in existing_versions if v.is_active), None)
|
|
||||||
if existing_versions
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
graph.version = (
|
|
||||||
max(v.version for v in existing_versions) + 1 if existing_versions else 1
|
|
||||||
)
|
|
||||||
|
|
||||||
graph_model = graph_db.make_graph_model(graph, user_id)
|
|
||||||
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
|
||||||
|
|
||||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
|
||||||
|
|
||||||
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
|
||||||
if not library_agent:
|
|
||||||
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
|
||||||
|
|
||||||
library_agent = await update_library_agent_version_and_settings(
|
|
||||||
user_id, created_graph
|
|
||||||
)
|
|
||||||
|
|
||||||
if created_graph.is_active:
|
|
||||||
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
|
||||||
await graph_db.set_graph_active_version(
|
|
||||||
graph_id=created_graph.id,
|
|
||||||
version=created_graph.version,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
if current_active_version:
|
|
||||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
|
||||||
|
|
||||||
return created_graph, library_agent
|
|
||||||
|
|
||||||
|
|
||||||
async def update_library_agent_version_and_settings(
|
|
||||||
user_id: str, agent_graph: graph_db.GraphModel
|
|
||||||
) -> library_model.LibraryAgent:
|
|
||||||
"""Update library agent to point to new graph version and sync settings."""
|
|
||||||
library = await update_agent_version_in_library(
|
|
||||||
user_id, agent_graph.id, agent_graph.version
|
|
||||||
)
|
|
||||||
updated_settings = GraphSettings.from_graph(
|
|
||||||
graph=agent_graph,
|
|
||||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
|
||||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
|
||||||
)
|
|
||||||
if updated_settings != library.settings:
|
|
||||||
library = await update_library_agent(
|
|
||||||
library_agent_id=library.id,
|
|
||||||
user_id=user_id,
|
|
||||||
settings=updated_settings,
|
|
||||||
)
|
|
||||||
return library
|
|
||||||
|
|
||||||
|
|
||||||
async def update_library_agent(
|
async def update_library_agent(
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|||||||
@@ -1,414 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) API routes.
|
|
||||||
|
|
||||||
Provides endpoints for MCP tool discovery and OAuth authentication so the
|
|
||||||
frontend can list available tools on an MCP server before placing a block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Annotated, Any
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import fastapi
|
|
||||||
from autogpt_libs.auth import get_user_id
|
|
||||||
from fastapi import Security
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from backend.api.features.integrations.router import CredentialsMetaResponse
|
|
||||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
|
||||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.request import HTTPClientError, Requests
|
|
||||||
from backend.util.settings import Settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
router = fastapi.APIRouter(tags=["mcp"])
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Tool Discovery ====================== #
|
|
||||||
|
|
||||||
|
|
||||||
class DiscoverToolsRequest(BaseModel):
|
|
||||||
"""Request to discover tools on an MCP server."""
|
|
||||||
|
|
||||||
server_url: str = Field(description="URL of the MCP server")
|
|
||||||
auth_token: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Optional Bearer token for authenticated MCP servers",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPToolResponse(BaseModel):
|
|
||||||
"""A single MCP tool returned by discovery."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
input_schema: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class DiscoverToolsResponse(BaseModel):
|
|
||||||
"""Response containing the list of tools available on an MCP server."""
|
|
||||||
|
|
||||||
tools: list[MCPToolResponse]
|
|
||||||
server_name: str | None = None
|
|
||||||
protocol_version: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/discover-tools",
|
|
||||||
summary="Discover available tools on an MCP server",
|
|
||||||
response_model=DiscoverToolsResponse,
|
|
||||||
)
|
|
||||||
async def discover_tools(
|
|
||||||
request: DiscoverToolsRequest,
|
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
|
||||||
) -> DiscoverToolsResponse:
|
|
||||||
"""
|
|
||||||
Connect to an MCP server and return its available tools.
|
|
||||||
|
|
||||||
If the user has a stored MCP credential for this server URL, it will be
|
|
||||||
used automatically — no need to pass an explicit auth token.
|
|
||||||
"""
|
|
||||||
auth_token = request.auth_token
|
|
||||||
|
|
||||||
# Auto-use stored MCP credential when no explicit token is provided.
|
|
||||||
if not auth_token:
|
|
||||||
try:
|
|
||||||
mcp_creds = await creds_manager.store.get_creds_by_provider(
|
|
||||||
user_id, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
# Find the freshest credential for this server URL
|
|
||||||
best_cred: OAuth2Credentials | None = None
|
|
||||||
for cred in mcp_creds:
|
|
||||||
if (
|
|
||||||
isinstance(cred, OAuth2Credentials)
|
|
||||||
and cred.metadata.get("mcp_server_url") == request.server_url
|
|
||||||
):
|
|
||||||
if best_cred is None or (
|
|
||||||
(cred.access_token_expires_at or 0)
|
|
||||||
> (best_cred.access_token_expires_at or 0)
|
|
||||||
):
|
|
||||||
best_cred = cred
|
|
||||||
if best_cred:
|
|
||||||
# Refresh the token if expired before using it
|
|
||||||
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
|
|
||||||
logger.info(
|
|
||||||
f"Using MCP credential {best_cred.id} for {request.server_url}, "
|
|
||||||
f"expires_at={best_cred.access_token_expires_at}"
|
|
||||||
)
|
|
||||||
auth_token = best_cred.access_token.get_secret_value()
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not look up stored MCP credentials", exc_info=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
client = MCPClient(request.server_url, auth_token=auth_token)
|
|
||||||
|
|
||||||
init_result = await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
return DiscoverToolsResponse(
|
|
||||||
tools=[
|
|
||||||
MCPToolResponse(
|
|
||||||
name=t.name,
|
|
||||||
description=t.description,
|
|
||||||
input_schema=t.input_schema,
|
|
||||||
)
|
|
||||||
for t in tools
|
|
||||||
],
|
|
||||||
server_name=init_result.get("serverInfo", {}).get("name"),
|
|
||||||
protocol_version=init_result.get("protocolVersion"),
|
|
||||||
)
|
|
||||||
except HTTPClientError as e:
|
|
||||||
if e.status_code in (401, 403):
|
|
||||||
logger.warning(
|
|
||||||
f"MCP server returned {e.status_code} for {request.server_url}: {e}"
|
|
||||||
)
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail="This MCP server requires authentication. "
|
|
||||||
"Please provide a valid auth token.",
|
|
||||||
)
|
|
||||||
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
|
||||||
except MCPClientError as e:
|
|
||||||
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("MCP tool discovery failed")
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=502,
|
|
||||||
detail=f"Failed to connect to MCP server: {str(e)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================== OAuth Flow ======================== #
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthLoginRequest(BaseModel):
|
|
||||||
"""Request to start an OAuth flow for an MCP server."""
|
|
||||||
|
|
||||||
server_url: str = Field(description="URL of the MCP server that requires OAuth")
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthLoginResponse(BaseModel):
|
|
||||||
"""Response with the OAuth login URL for the user to authenticate."""
|
|
||||||
|
|
||||||
login_url: str
|
|
||||||
state_token: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/oauth/login",
|
|
||||||
summary="Initiate OAuth login for an MCP server",
|
|
||||||
)
|
|
||||||
async def mcp_oauth_login(
|
|
||||||
request: MCPOAuthLoginRequest,
|
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
|
||||||
) -> MCPOAuthLoginResponse:
|
|
||||||
"""
|
|
||||||
Discover OAuth metadata from the MCP server and return a login URL.
|
|
||||||
|
|
||||||
1. Discovers the protected-resource metadata (RFC 9728)
|
|
||||||
2. Fetches the authorization server metadata (RFC 8414)
|
|
||||||
3. Performs Dynamic Client Registration (RFC 7591) if available
|
|
||||||
4. Returns the authorization URL for the frontend to open in a popup
|
|
||||||
"""
|
|
||||||
client = MCPClient(request.server_url)
|
|
||||||
|
|
||||||
# Step 1: Discover protected-resource metadata (RFC 9728)
|
|
||||||
try:
|
|
||||||
protected_resource = await client.discover_auth()
|
|
||||||
except Exception as e:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=502,
|
|
||||||
detail=f"Failed to discover OAuth metadata: {e}",
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
if protected_resource and "authorization_servers" in protected_resource:
|
|
||||||
auth_server_url = protected_resource["authorization_servers"][0]
|
|
||||||
resource_url = protected_resource.get("resource", request.server_url)
|
|
||||||
|
|
||||||
# Step 2a: Discover auth-server metadata (RFC 8414)
|
|
||||||
try:
|
|
||||||
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
|
||||||
except Exception as e:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=502,
|
|
||||||
detail=f"Failed to discover authorization server metadata: {e}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
|
|
||||||
# and serve OAuth metadata directly without protected-resource metadata.
|
|
||||||
# Don't assume a resource_url — omitting it lets the auth server choose
|
|
||||||
# the correct audience for the token (RFC 8707 resource is optional).
|
|
||||||
resource_url = None
|
|
||||||
try:
|
|
||||||
metadata = await client.discover_auth_server_metadata(request.server_url)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not metadata or "authorization_endpoint" not in metadata:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="This MCP server does not advertise OAuth support. "
|
|
||||||
"You may need to provide an auth token manually.",
|
|
||||||
)
|
|
||||||
|
|
||||||
authorize_url = metadata["authorization_endpoint"]
|
|
||||||
token_url = metadata["token_endpoint"]
|
|
||||||
registration_endpoint = metadata.get("registration_endpoint")
|
|
||||||
revoke_url = metadata.get("revocation_endpoint")
|
|
||||||
|
|
||||||
# Step 3: Dynamic Client Registration (RFC 7591) if available
|
|
||||||
frontend_base_url = settings.config.frontend_base_url
|
|
||||||
if not frontend_base_url:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="Frontend base URL is not configured.",
|
|
||||||
)
|
|
||||||
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
|
||||||
|
|
||||||
client_id = ""
|
|
||||||
client_secret = ""
|
|
||||||
if registration_endpoint:
|
|
||||||
reg_result = await _register_mcp_client(
|
|
||||||
registration_endpoint, redirect_uri, request.server_url
|
|
||||||
)
|
|
||||||
if reg_result:
|
|
||||||
client_id = reg_result.get("client_id", "")
|
|
||||||
client_secret = reg_result.get("client_secret", "")
|
|
||||||
|
|
||||||
if not client_id:
|
|
||||||
client_id = "autogpt-platform"
|
|
||||||
|
|
||||||
# Step 4: Store state token with OAuth metadata for the callback
|
|
||||||
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
|
|
||||||
"scopes_supported", []
|
|
||||||
)
|
|
||||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
|
||||||
user_id,
|
|
||||||
ProviderName.MCP.value,
|
|
||||||
scopes,
|
|
||||||
state_metadata={
|
|
||||||
"authorize_url": authorize_url,
|
|
||||||
"token_url": token_url,
|
|
||||||
"revoke_url": revoke_url,
|
|
||||||
"resource_url": resource_url,
|
|
||||||
"server_url": request.server_url,
|
|
||||||
"client_id": client_id,
|
|
||||||
"client_secret": client_secret,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 5: Build and return the login URL
|
|
||||||
handler = MCPOAuthHandler(
|
|
||||||
client_id=client_id,
|
|
||||||
client_secret=client_secret,
|
|
||||||
redirect_uri=redirect_uri,
|
|
||||||
authorize_url=authorize_url,
|
|
||||||
token_url=token_url,
|
|
||||||
resource_url=resource_url,
|
|
||||||
)
|
|
||||||
login_url = handler.get_login_url(
|
|
||||||
scopes, state_token, code_challenge=code_challenge
|
|
||||||
)
|
|
||||||
|
|
||||||
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthCallbackRequest(BaseModel):
|
|
||||||
"""Request to exchange an OAuth code for tokens."""
|
|
||||||
|
|
||||||
code: str = Field(description="Authorization code from OAuth callback")
|
|
||||||
state_token: str = Field(description="State token for CSRF verification")
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthCallbackResponse(BaseModel):
|
|
||||||
"""Response after successfully storing OAuth credentials."""
|
|
||||||
|
|
||||||
credential_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
summary="Exchange OAuth code for MCP tokens",
|
|
||||||
)
|
|
||||||
async def mcp_oauth_callback(
|
|
||||||
request: MCPOAuthCallbackRequest,
|
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
|
||||||
) -> CredentialsMetaResponse:
|
|
||||||
"""
|
|
||||||
Exchange the authorization code for tokens and store the credential.
|
|
||||||
|
|
||||||
The frontend calls this after receiving the OAuth code from the popup.
|
|
||||||
On success, subsequent ``/discover-tools`` calls for the same server URL
|
|
||||||
will automatically use the stored credential.
|
|
||||||
"""
|
|
||||||
valid_state = await creds_manager.store.verify_state_token(
|
|
||||||
user_id, request.state_token, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
if not valid_state:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Invalid or expired state token.",
|
|
||||||
)
|
|
||||||
|
|
||||||
meta = valid_state.state_metadata
|
|
||||||
frontend_base_url = settings.config.frontend_base_url
|
|
||||||
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
|
||||||
|
|
||||||
handler = MCPOAuthHandler(
|
|
||||||
client_id=meta["client_id"],
|
|
||||||
client_secret=meta.get("client_secret", ""),
|
|
||||||
redirect_uri=redirect_uri,
|
|
||||||
authorize_url=meta["authorize_url"],
|
|
||||||
token_url=meta["token_url"],
|
|
||||||
revoke_url=meta.get("revoke_url"),
|
|
||||||
resource_url=meta.get("resource_url"),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
credentials = await handler.exchange_code_for_tokens(
|
|
||||||
request.code, valid_state.scopes, valid_state.code_verifier
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("MCP OAuth token exchange failed")
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"OAuth token exchange failed: {e}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Enrich credential metadata for future lookup and token refresh
|
|
||||||
if credentials.metadata is None:
|
|
||||||
credentials.metadata = {}
|
|
||||||
credentials.metadata["mcp_server_url"] = meta["server_url"]
|
|
||||||
credentials.metadata["mcp_client_id"] = meta["client_id"]
|
|
||||||
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
|
|
||||||
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
|
||||||
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
|
||||||
|
|
||||||
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
|
|
||||||
credentials.title = f"MCP: {hostname}"
|
|
||||||
|
|
||||||
# Remove old MCP credentials for the same server to prevent stale token buildup.
|
|
||||||
try:
|
|
||||||
old_creds = await creds_manager.store.get_creds_by_provider(
|
|
||||||
user_id, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
for old in old_creds:
|
|
||||||
if (
|
|
||||||
isinstance(old, OAuth2Credentials)
|
|
||||||
and old.metadata.get("mcp_server_url") == meta["server_url"]
|
|
||||||
):
|
|
||||||
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
|
||||||
logger.info(
|
|
||||||
f"Removed old MCP credential {old.id} for {meta['server_url']}"
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
|
||||||
|
|
||||||
await creds_manager.create(user_id, credentials)
|
|
||||||
|
|
||||||
return CredentialsMetaResponse(
|
|
||||||
id=credentials.id,
|
|
||||||
provider=credentials.provider,
|
|
||||||
type=credentials.type,
|
|
||||||
title=credentials.title,
|
|
||||||
scopes=credentials.scopes,
|
|
||||||
username=credentials.username,
|
|
||||||
host=credentials.metadata.get("mcp_server_url"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================== Helpers ======================== #
|
|
||||||
|
|
||||||
|
|
||||||
async def _register_mcp_client(
|
|
||||||
registration_endpoint: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
server_url: str,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
|
|
||||||
try:
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
registration_endpoint,
|
|
||||||
json={
|
|
||||||
"client_name": "AutoGPT Platform",
|
|
||||||
"redirect_uris": [redirect_uri],
|
|
||||||
"grant_types": ["authorization_code"],
|
|
||||||
"response_types": ["code"],
|
|
||||||
"token_endpoint_auth_method": "client_secret_post",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
data = response.json()
|
|
||||||
if isinstance(data, dict) and "client_id" in data:
|
|
||||||
return data
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
|
|
||||||
return None
|
|
||||||
@@ -1,389 +0,0 @@
|
|||||||
"""Tests for MCP API routes."""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import fastapi
|
|
||||||
import fastapi.testclient
|
|
||||||
from autogpt_libs.auth import get_user_id
|
|
||||||
|
|
||||||
from backend.api.features.mcp.routes import router
|
|
||||||
from backend.blocks.mcp.client import MCPClientError, MCPTool
|
|
||||||
from backend.util.request import HTTPClientError
|
|
||||||
|
|
||||||
app = fastapi.FastAPI()
|
|
||||||
app.include_router(router)
|
|
||||||
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
|
|
||||||
client = fastapi.testclient.TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDiscoverTools:
|
|
||||||
def test_discover_tools_success(self):
|
|
||||||
mock_tools = [
|
|
||||||
MCPTool(
|
|
||||||
name="get_weather",
|
|
||||||
description="Get weather for a city",
|
|
||||||
input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
MCPTool(
|
|
||||||
name="add_numbers",
|
|
||||||
description="Add two numbers",
|
|
||||||
input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"a": {"type": "number"},
|
|
||||||
"b": {"type": "number"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
with (patch("backend.api.features.mcp.routes.MCPClient") as MockClient,):
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"serverInfo": {"name": "test-server"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
instance.list_tools = AsyncMock(return_value=mock_tools)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert len(data["tools"]) == 2
|
|
||||||
assert data["tools"][0]["name"] == "get_weather"
|
|
||||||
assert data["tools"][1]["name"] == "add_numbers"
|
|
||||||
assert data["server_name"] == "test-server"
|
|
||||||
assert data["protocol_version"] == "2025-03-26"
|
|
||||||
|
|
||||||
def test_discover_tools_with_auth_token(self):
|
|
||||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
|
||||||
)
|
|
||||||
instance.list_tools = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
"auth_token": "my-secret-token",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
MockClient.assert_called_once_with(
|
|
||||||
"https://mcp.example.com/mcp",
|
|
||||||
auth_token="my-secret-token",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_discover_tools_auto_uses_stored_credential(self):
|
|
||||||
"""When no explicit token is given, stored MCP credentials are used."""
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
stored_cred = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
title="MCP: example.com",
|
|
||||||
access_token=SecretStr("stored-token-123"),
|
|
||||||
refresh_token=None,
|
|
||||||
access_token_expires_at=None,
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=[],
|
|
||||||
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
|
|
||||||
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
|
||||||
)
|
|
||||||
instance.list_tools = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
MockClient.assert_called_once_with(
|
|
||||||
"https://mcp.example.com/mcp",
|
|
||||||
auth_token="stored-token-123",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_discover_tools_mcp_error(self):
|
|
||||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
side_effect=MCPClientError("Connection refused")
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://bad-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 502
|
|
||||||
assert "Connection refused" in response.json()["detail"]
|
|
||||||
|
|
||||||
def test_discover_tools_generic_error(self):
|
|
||||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://timeout.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 502
|
|
||||||
assert "Failed to connect" in response.json()["detail"]
|
|
||||||
|
|
||||||
def test_discover_tools_auth_required(self):
|
|
||||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 401
|
|
||||||
assert "requires authentication" in response.json()["detail"]
|
|
||||||
|
|
||||||
def test_discover_tools_forbidden(self):
|
|
||||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 401
|
|
||||||
assert "requires authentication" in response.json()["detail"]
|
|
||||||
|
|
||||||
def test_discover_tools_missing_url(self):
|
|
||||||
response = client.post("/discover-tools", json={})
|
|
||||||
assert response.status_code == 422
|
|
||||||
|
|
||||||
|
|
||||||
class TestOAuthLogin:
|
|
||||||
def test_oauth_login_success(self):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
patch(
|
|
||||||
"backend.api.features.mcp.routes._register_mcp_client"
|
|
||||||
) as mock_register,
|
|
||||||
):
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.discover_auth = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_servers": ["https://auth.sentry.io"],
|
|
||||||
"resource": "https://mcp.sentry.dev/mcp",
|
|
||||||
"scopes_supported": ["openid"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
instance.discover_auth_server_metadata = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_endpoint": "https://auth.sentry.io/authorize",
|
|
||||||
"token_endpoint": "https://auth.sentry.io/token",
|
|
||||||
"registration_endpoint": "https://auth.sentry.io/register",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
mock_register.return_value = {
|
|
||||||
"client_id": "registered-client-id",
|
|
||||||
"client_secret": "registered-secret",
|
|
||||||
}
|
|
||||||
mock_cm.store.store_state_token = AsyncMock(
|
|
||||||
return_value=("state-token-123", "code-challenge-abc")
|
|
||||||
)
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/oauth/login",
|
|
||||||
json={"server_url": "https://mcp.sentry.dev/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "login_url" in data
|
|
||||||
assert data["state_token"] == "state-token-123"
|
|
||||||
assert "auth.sentry.io/authorize" in data["login_url"]
|
|
||||||
assert "registered-client-id" in data["login_url"]
|
|
||||||
|
|
||||||
def test_oauth_login_no_oauth_support(self):
|
|
||||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.discover_auth = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/oauth/login",
|
|
||||||
json={"server_url": "https://simple-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "does not advertise OAuth" in response.json()["detail"]
|
|
||||||
|
|
||||||
def test_oauth_login_fallback_to_public_client(self):
|
|
||||||
"""When DCR is unavailable, falls back to default public client ID."""
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
):
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.discover_auth = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_servers": ["https://auth.example.com"],
|
|
||||||
"resource": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
instance.discover_auth_server_metadata = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
|
||||||
"token_endpoint": "https://auth.example.com/token",
|
|
||||||
# No registration_endpoint
|
|
||||||
}
|
|
||||||
)
|
|
||||||
mock_cm.store.store_state_token = AsyncMock(
|
|
||||||
return_value=("state-abc", "challenge-xyz")
|
|
||||||
)
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/oauth/login",
|
|
||||||
json={"server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "autogpt-platform" in data["login_url"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestOAuthCallback:
|
|
||||||
def test_oauth_callback_success(self):
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
mock_creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
title=None,
|
|
||||||
access_token=SecretStr("access-token-xyz"),
|
|
||||||
refresh_token=None,
|
|
||||||
access_token_expires_at=None,
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=[],
|
|
||||||
metadata={
|
|
||||||
"mcp_token_url": "https://auth.sentry.io/token",
|
|
||||||
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
|
||||||
):
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
|
|
||||||
# Mock state verification
|
|
||||||
mock_state = AsyncMock()
|
|
||||||
mock_state.state_metadata = {
|
|
||||||
"authorize_url": "https://auth.sentry.io/authorize",
|
|
||||||
"token_url": "https://auth.sentry.io/token",
|
|
||||||
"client_id": "test-client-id",
|
|
||||||
"client_secret": "test-secret",
|
|
||||||
"server_url": "https://mcp.sentry.dev/mcp",
|
|
||||||
}
|
|
||||||
mock_state.scopes = ["openid"]
|
|
||||||
mock_state.code_verifier = "verifier-123"
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
|
||||||
mock_cm.create = AsyncMock()
|
|
||||||
|
|
||||||
handler_instance = MockHandler.return_value
|
|
||||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
|
||||||
return_value=mock_creds
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock old credential cleanup
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "auth-code-abc", "state_token": "state-token-123"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "id" in data
|
|
||||||
assert data["provider"] == "mcp"
|
|
||||||
assert data["type"] == "oauth2"
|
|
||||||
mock_cm.create.assert_called_once()
|
|
||||||
|
|
||||||
def test_oauth_callback_invalid_state(self):
|
|
||||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "auth-code", "state_token": "bad-state"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "Invalid or expired" in response.json()["detail"]
|
|
||||||
|
|
||||||
def test_oauth_callback_token_exchange_fails(self):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
|
||||||
):
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
mock_state = AsyncMock()
|
|
||||||
mock_state.state_metadata = {
|
|
||||||
"authorize_url": "https://auth.example.com/authorize",
|
|
||||||
"token_url": "https://auth.example.com/token",
|
|
||||||
"client_id": "cid",
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
mock_state.scopes = []
|
|
||||||
mock_state.code_verifier = "v"
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
|
||||||
|
|
||||||
handler_instance = MockHandler.return_value
|
|
||||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
|
||||||
side_effect=RuntimeError("Token exchange failed")
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "bad-code", "state_token": "state"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "token exchange failed" in response.json()["detail"].lower()
|
|
||||||
@@ -20,6 +20,7 @@ from typing import AsyncGenerator
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||||
from prisma.enums import APIKeyPermission
|
from prisma.enums import APIKeyPermission
|
||||||
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
||||||
@@ -38,13 +39,13 @@ keysmith = APIKeySmith()
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(scope="session")
|
||||||
def test_user_id() -> str:
|
def test_user_id() -> str:
|
||||||
"""Test user ID for OAuth tests."""
|
"""Test user ID for OAuth tests."""
|
||||||
return str(uuid.uuid4())
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||||
async def test_user(server, test_user_id: str):
|
async def test_user(server, test_user_id: str):
|
||||||
"""Create a test user in the database."""
|
"""Create a test user in the database."""
|
||||||
await PrismaUser.prisma().create(
|
await PrismaUser.prisma().create(
|
||||||
@@ -67,7 +68,7 @@ async def test_user(server, test_user_id: str):
|
|||||||
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def test_oauth_app(test_user: str):
|
async def test_oauth_app(test_user: str):
|
||||||
"""Create a test OAuth application in the database."""
|
"""Create a test OAuth application in the database."""
|
||||||
app_id = str(uuid.uuid4())
|
app_id = str(uuid.uuid4())
|
||||||
@@ -122,7 +123,7 @@ def pkce_credentials() -> tuple[str, str]:
|
|||||||
return generate_pkce()
|
return generate_pkce()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
|
async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||||
"""
|
"""
|
||||||
Create an async HTTP client that talks directly to the FastAPI app.
|
Create an async HTTP client that talks directly to the FastAPI app.
|
||||||
@@ -287,7 +288,7 @@ async def test_authorize_invalid_client_returns_error(
|
|||||||
assert query_params["error"][0] == "invalid_client"
|
assert query_params["error"][0] == "invalid_client"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def inactive_oauth_app(test_user: str):
|
async def inactive_oauth_app(test_user: str):
|
||||||
"""Create an inactive test OAuth application in the database."""
|
"""Create an inactive test OAuth application in the database."""
|
||||||
app_id = str(uuid.uuid4())
|
app_id = str(uuid.uuid4())
|
||||||
@@ -1004,7 +1005,7 @@ async def test_token_refresh_revoked(
|
|||||||
assert "revoked" in response.json()["detail"].lower()
|
assert "revoked" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def other_oauth_app(test_user: str):
|
async def other_oauth_app(test_user: str):
|
||||||
"""Create a second OAuth application for cross-app tests."""
|
"""Create a second OAuth application for cross-app tests."""
|
||||||
app_id = str(uuid.uuid4())
|
app_id = str(uuid.uuid4())
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Literal, overload
|
from typing import Any, Literal
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
@@ -11,8 +11,8 @@ import prisma.types
|
|||||||
|
|
||||||
from backend.data.db import transaction
|
from backend.data.db import transaction
|
||||||
from backend.data.graph import (
|
from backend.data.graph import (
|
||||||
|
GraphMeta,
|
||||||
GraphModel,
|
GraphModel,
|
||||||
GraphModelWithoutNodes,
|
|
||||||
get_graph,
|
get_graph,
|
||||||
get_graph_as_admin,
|
get_graph_as_admin,
|
||||||
get_sub_graphs,
|
get_sub_graphs,
|
||||||
@@ -334,22 +334,7 @@ async def get_store_agent_details(
|
|||||||
raise DatabaseError("Failed to fetch agent details") from e
|
raise DatabaseError("Failed to fetch agent details") from e
|
||||||
|
|
||||||
|
|
||||||
@overload
|
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||||
async def get_available_graph(
|
|
||||||
store_listing_version_id: str, hide_nodes: Literal[False]
|
|
||||||
) -> GraphModel: ...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
|
||||||
async def get_available_graph(
|
|
||||||
store_listing_version_id: str, hide_nodes: Literal[True] = True
|
|
||||||
) -> GraphModelWithoutNodes: ...
|
|
||||||
|
|
||||||
|
|
||||||
async def get_available_graph(
|
|
||||||
store_listing_version_id: str,
|
|
||||||
hide_nodes: bool = True,
|
|
||||||
) -> GraphModelWithoutNodes | GraphModel:
|
|
||||||
try:
|
try:
|
||||||
# Get avaialble, non-deleted store listing version
|
# Get avaialble, non-deleted store listing version
|
||||||
store_listing_version = (
|
store_listing_version = (
|
||||||
@@ -359,7 +344,7 @@ async def get_available_graph(
|
|||||||
"isAvailable": True,
|
"isAvailable": True,
|
||||||
"isDeleted": False,
|
"isDeleted": False,
|
||||||
},
|
},
|
||||||
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
|
include={"AgentGraph": {"include": {"Nodes": True}}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -369,9 +354,7 @@ async def get_available_graph(
|
|||||||
detail=f"Store listing version {store_listing_version_id} not found",
|
detail=f"Store listing version {store_listing_version_id} not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
|
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
|
||||||
store_listing_version.AgentGraph
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting agent: {e}")
|
logger.error(f"Error getting agent: {e}")
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ Includes BM25 reranking for improved lexical relevance.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@@ -363,11 +362,7 @@ async def unified_hybrid_search(
|
|||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
results = await query_raw_with_schema(sql_query, *params)
|
||||||
results = await query_raw_with_schema(sql_query, *params)
|
|
||||||
except Exception as e:
|
|
||||||
await _log_vector_error_diagnostics(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
total = results[0]["total_count"] if results else 0
|
total = results[0]["total_count"] if results else 0
|
||||||
# Apply BM25 reranking
|
# Apply BM25 reranking
|
||||||
@@ -691,11 +686,7 @@ async def hybrid_search(
|
|||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
results = await query_raw_with_schema(sql_query, *params)
|
||||||
results = await query_raw_with_schema(sql_query, *params)
|
|
||||||
except Exception as e:
|
|
||||||
await _log_vector_error_diagnostics(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
total = results[0]["total_count"] if results else 0
|
total = results[0]["total_count"] if results else 0
|
||||||
|
|
||||||
@@ -727,87 +718,6 @@ async def hybrid_search_simple(
|
|||||||
return await hybrid_search(query=query, page=page, page_size=page_size)
|
return await hybrid_search(query=query, page=page, page_size=page_size)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Diagnostics
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
# Rate limit: only log vector error diagnostics once per this interval
|
|
||||||
_VECTOR_DIAG_INTERVAL_SECONDS = 60
|
|
||||||
_last_vector_diag_time: float = 0
|
|
||||||
|
|
||||||
|
|
||||||
async def _log_vector_error_diagnostics(error: Exception) -> None:
|
|
||||||
"""Log diagnostic info when 'type vector does not exist' error occurs.
|
|
||||||
|
|
||||||
Note: Diagnostic queries use query_raw_with_schema which may run on a different
|
|
||||||
pooled connection than the one that failed. Session-level search_path can differ,
|
|
||||||
so these diagnostics show cluster-wide state, not necessarily the failed session.
|
|
||||||
|
|
||||||
Includes rate limiting to avoid log spam - only logs once per minute.
|
|
||||||
Caller should re-raise the error after calling this function.
|
|
||||||
"""
|
|
||||||
global _last_vector_diag_time
|
|
||||||
|
|
||||||
# Check if this is the vector type error
|
|
||||||
error_str = str(error).lower()
|
|
||||||
if not (
|
|
||||||
"type" in error_str and "vector" in error_str and "does not exist" in error_str
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Rate limit: only log once per interval
|
|
||||||
now = time.time()
|
|
||||||
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
|
|
||||||
return
|
|
||||||
_last_vector_diag_time = now
|
|
||||||
|
|
||||||
try:
|
|
||||||
diagnostics: dict[str, object] = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
search_path_result = await query_raw_with_schema("SHOW search_path")
|
|
||||||
diagnostics["search_path"] = search_path_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["search_path"] = f"Error: {e}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
schema_result = await query_raw_with_schema("SELECT current_schema()")
|
|
||||||
diagnostics["current_schema"] = schema_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["current_schema"] = f"Error: {e}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
user_result = await query_raw_with_schema(
|
|
||||||
"SELECT current_user, session_user, current_database()"
|
|
||||||
)
|
|
||||||
diagnostics["user_info"] = user_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["user_info"] = f"Error: {e}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check pgvector extension installation (cluster-wide, stable info)
|
|
||||||
ext_result = await query_raw_with_schema(
|
|
||||||
"SELECT extname, extversion, nspname as schema "
|
|
||||||
"FROM pg_extension e "
|
|
||||||
"JOIN pg_namespace n ON e.extnamespace = n.oid "
|
|
||||||
"WHERE extname = 'vector'"
|
|
||||||
)
|
|
||||||
diagnostics["pgvector_extension"] = ext_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["pgvector_extension"] = f"Error: {e}"
|
|
||||||
|
|
||||||
logger.error(
|
|
||||||
f"Vector type error diagnostics:\n"
|
|
||||||
f" Error: {error}\n"
|
|
||||||
f" search_path: {diagnostics.get('search_path')}\n"
|
|
||||||
f" current_schema: {diagnostics.get('current_schema')}\n"
|
|
||||||
f" user_info: {diagnostics.get('user_info')}\n"
|
|
||||||
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
|
|
||||||
)
|
|
||||||
except Exception as diag_error:
|
|
||||||
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
|
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
||||||
# for existing code that expects the popularity parameter
|
# for existing code that expects the popularity parameter
|
||||||
HybridSearchWeights = StoreAgentSearchWeights
|
HybridSearchWeights = StoreAgentSearchWeights
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
|
|||||||
StyleType,
|
StyleType,
|
||||||
UpscaleOption,
|
UpscaleOption,
|
||||||
)
|
)
|
||||||
from backend.data.graph import GraphBaseMeta
|
from backend.data.graph import BaseGraph
|
||||||
from backend.data.model import CredentialsMetaInput, ProviderName
|
from backend.data.model import CredentialsMetaInput, ProviderName
|
||||||
from backend.integrations.credentials_store import ideogram_credentials
|
from backend.integrations.credentials_store import ideogram_credentials
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
|
|||||||
DIGITAL_ART = "digital art"
|
DIGITAL_ART = "digital art"
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||||
if settings.config.use_agent_image_generation_v2:
|
if settings.config.use_agent_image_generation_v2:
|
||||||
return await generate_agent_image_v2(graph=agent)
|
return await generate_agent_image_v2(graph=agent)
|
||||||
else:
|
else:
|
||||||
return await generate_agent_image_v1(agent=agent)
|
return await generate_agent_image_v1(agent=agent)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Generate an image for an agent using Ideogram model.
|
Generate an image for an agent using Ideogram model.
|
||||||
Returns:
|
Returns:
|
||||||
@@ -54,17 +54,14 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
|
|||||||
description = f"{name} ({graph.description})" if graph.description else name
|
description = f"{name} ({graph.description})" if graph.description else name
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
"Create a visually striking retro-futuristic vector pop art illustration "
|
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
|
||||||
f'prominently featuring "{name}" in bold typography. The image clearly and '
|
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
|
||||||
f"literally depicts a {description}, along with recognizable objects directly "
|
f"along with recognizable objects directly associated with the primary function of a {name}. "
|
||||||
f"associated with the primary function of a {name}. "
|
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
|
||||||
f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
|
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
|
||||||
f"clearly conveying the purpose of a {name}. "
|
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
|
||||||
"Maintain vibrant, limited-palette colors, sharp vector lines, "
|
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
|
||||||
"geometric shapes, flat illustration techniques, and solid colors "
|
f"prioritizing clear visual storytelling and thematic clarity above all else."
|
||||||
"without gradients or shading. Preserve a retro-futuristic aesthetic "
|
|
||||||
"influenced by mid-century futurism and 1960s psychedelia, "
|
|
||||||
"prioritizing clear visual storytelling and thematic clarity above all else."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
custom_colors = [
|
custom_colors = [
|
||||||
@@ -102,12 +99,12 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
|
|||||||
return io.BytesIO(response.content)
|
return io.BytesIO(response.content)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Generate an image for an agent using Flux model via Replicate API.
|
Generate an image for an agent using Flux model via Replicate API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent (GraphBaseMeta | AgentGraph): The agent to generate an image for
|
agent (Graph): The agent to generate an image for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
io.BytesIO: The generated image as bytes
|
io.BytesIO: The generated image as bytes
|
||||||
@@ -117,13 +114,7 @@ async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.Bytes
|
|||||||
raise ValueError("Missing Replicate API key in settings")
|
raise ValueError("Missing Replicate API key in settings")
|
||||||
|
|
||||||
# Construct prompt from agent details
|
# Construct prompt from agent details
|
||||||
prompt = (
|
prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
|
||||||
"Create a visually engaging app store thumbnail for the AI agent "
|
|
||||||
"that highlights what it does in a clear and captivating way:\n"
|
|
||||||
f"- **Name**: {agent.name}\n"
|
|
||||||
f"- **Description**: {agent.description}\n"
|
|
||||||
f"Focus on showcasing its core functionality with an appealing design."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set up Replicate client
|
# Set up Replicate client
|
||||||
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ async def get_agent(
|
|||||||
)
|
)
|
||||||
async def get_graph_meta_by_store_listing_version_id(
|
async def get_graph_meta_by_store_listing_version_id(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
) -> backend.data.graph.GraphMeta:
|
||||||
"""
|
"""
|
||||||
Get Agent Graph from Store Listing Version ID.
|
Get Agent Graph from Store Listing Version ID.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ from backend.util.timezone_utils import (
|
|||||||
from backend.util.virus_scanner import scan_content_safe
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
|
||||||
from .library import db as library_db
|
from .library import db as library_db
|
||||||
|
from .library import model as library_model
|
||||||
from .store.model import StoreAgentDetails
|
from .store.model import StoreAgentDetails
|
||||||
|
|
||||||
|
|
||||||
@@ -822,16 +823,18 @@ async def update_graph(
|
|||||||
graph: graph_db.Graph,
|
graph: graph_db.Graph,
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
) -> graph_db.GraphModel:
|
) -> graph_db.GraphModel:
|
||||||
|
# Sanity check
|
||||||
if graph.id and graph.id != graph_id:
|
if graph.id and graph.id != graph_id:
|
||||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||||
|
|
||||||
|
# Determine new version
|
||||||
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||||
if not existing_versions:
|
if not existing_versions:
|
||||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||||
|
latest_version_number = max(g.version for g in existing_versions)
|
||||||
|
graph.version = latest_version_number + 1
|
||||||
|
|
||||||
graph.version = max(g.version for g in existing_versions) + 1
|
|
||||||
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
||||||
|
|
||||||
graph = graph_db.make_graph_model(graph, user_id)
|
graph = graph_db.make_graph_model(graph, user_id)
|
||||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||||
graph.validate_graph(for_run=False)
|
graph.validate_graph(for_run=False)
|
||||||
@@ -839,23 +842,27 @@ async def update_graph(
|
|||||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||||
|
|
||||||
if new_graph_version.is_active:
|
if new_graph_version.is_active:
|
||||||
await library_db.update_library_agent_version_and_settings(
|
# Keep the library agent up to date with the new active version
|
||||||
user_id, new_graph_version
|
await _update_library_agent_version_and_settings(user_id, new_graph_version)
|
||||||
)
|
|
||||||
|
# Handle activation of the new graph first to ensure continuity
|
||||||
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
||||||
|
# Ensure new version is the only active version
|
||||||
await graph_db.set_graph_active_version(
|
await graph_db.set_graph_active_version(
|
||||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||||
)
|
)
|
||||||
if current_active_version:
|
if current_active_version:
|
||||||
|
# Handle deactivation of the previously active version
|
||||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||||
|
|
||||||
|
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
|
||||||
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
||||||
graph_id,
|
graph_id,
|
||||||
new_graph_version.version,
|
new_graph_version.version,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
include_subgraphs=True,
|
include_subgraphs=True,
|
||||||
)
|
)
|
||||||
assert new_graph_version_with_subgraphs
|
assert new_graph_version_with_subgraphs # make type checker happy
|
||||||
return new_graph_version_with_subgraphs
|
return new_graph_version_with_subgraphs
|
||||||
|
|
||||||
|
|
||||||
@@ -893,15 +900,33 @@ async def set_graph_active_version(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Keep the library agent up to date with the new active version
|
# Keep the library agent up to date with the new active version
|
||||||
await library_db.update_library_agent_version_and_settings(
|
await _update_library_agent_version_and_settings(user_id, new_active_graph)
|
||||||
user_id, new_active_graph
|
|
||||||
)
|
|
||||||
|
|
||||||
if current_active_graph and current_active_graph.version != new_active_version:
|
if current_active_graph and current_active_graph.version != new_active_version:
|
||||||
# Handle deactivation of the previously active version
|
# Handle deactivation of the previously active version
|
||||||
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_library_agent_version_and_settings(
|
||||||
|
user_id: str, agent_graph: graph_db.GraphModel
|
||||||
|
) -> library_model.LibraryAgent:
|
||||||
|
library = await library_db.update_agent_version_in_library(
|
||||||
|
user_id, agent_graph.id, agent_graph.version
|
||||||
|
)
|
||||||
|
updated_settings = GraphSettings.from_graph(
|
||||||
|
graph=agent_graph,
|
||||||
|
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||||
|
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||||
|
)
|
||||||
|
if updated_settings != library.settings:
|
||||||
|
library = await library_db.update_library_agent(
|
||||||
|
library_agent_id=library.id,
|
||||||
|
user_id=user_id,
|
||||||
|
settings=updated_settings,
|
||||||
|
)
|
||||||
|
return library
|
||||||
|
|
||||||
|
|
||||||
@v1_router.patch(
|
@v1_router.patch(
|
||||||
path="/graphs/{graph_id}/settings",
|
path="/graphs/{graph_id}/settings",
|
||||||
summary="Update graph settings",
|
summary="Update graph settings",
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ import backend.api.features.executions.review.routes
|
|||||||
import backend.api.features.library.db
|
import backend.api.features.library.db
|
||||||
import backend.api.features.library.model
|
import backend.api.features.library.model
|
||||||
import backend.api.features.library.routes
|
import backend.api.features.library.routes
|
||||||
import backend.api.features.mcp.routes as mcp_routes
|
|
||||||
import backend.api.features.oauth
|
import backend.api.features.oauth
|
||||||
import backend.api.features.otto.routes
|
import backend.api.features.otto.routes
|
||||||
import backend.api.features.postmark.postmark
|
import backend.api.features.postmark.postmark
|
||||||
@@ -344,11 +343,6 @@ app.include_router(
|
|||||||
tags=["workspace"],
|
tags=["workspace"],
|
||||||
prefix="/api/workspace",
|
prefix="/api/workspace",
|
||||||
)
|
)
|
||||||
app.include_router(
|
|
||||||
mcp_routes.router,
|
|
||||||
tags=["v2", "mcp"],
|
|
||||||
prefix="/api/mcp",
|
|
||||||
)
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
backend.api.features.oauth.router,
|
backend.api.features.oauth.router,
|
||||||
tags=["oauth"],
|
tags=["oauth"],
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
|
|||||||
"""ElevenLabs integration blocks - test credentials and shared utilities."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
|
||||||
provider="elevenlabs",
|
|
||||||
api_key=SecretStr("mock-elevenlabs-api-key"),
|
|
||||||
title="Mock ElevenLabs API key",
|
|
||||||
expires_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
TEST_CREDENTIALS_INPUT = {
|
|
||||||
"provider": TEST_CREDENTIALS.provider,
|
|
||||||
"id": TEST_CREDENTIALS.id,
|
|
||||||
"type": TEST_CREDENTIALS.type,
|
|
||||||
"title": TEST_CREDENTIALS.title,
|
|
||||||
}
|
|
||||||
|
|
||||||
ElevenLabsCredentials = APIKeyCredentials
|
|
||||||
ElevenLabsCredentialsInput = CredentialsMetaInput[
|
|
||||||
Literal[ProviderName.ELEVENLABS], Literal["api_key"]
|
|
||||||
]
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
"""Text encoding block for converting special characters to escape sequences."""
|
|
||||||
|
|
||||||
import codecs
|
|
||||||
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
|
|
||||||
|
|
||||||
class TextEncoderBlock(Block):
|
|
||||||
"""
|
|
||||||
Encodes a string by converting special characters into escape sequences.
|
|
||||||
|
|
||||||
This block is the inverse of TextDecoderBlock. It takes text containing
|
|
||||||
special characters (like newlines, tabs, etc.) and converts them into
|
|
||||||
their escape sequence representations (e.g., newline becomes \\n).
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
"""Input schema for TextEncoderBlock."""
|
|
||||||
|
|
||||||
text: str = SchemaField(
|
|
||||||
description="A string containing special characters to be encoded",
|
|
||||||
placeholder="Your text with newlines and quotes to encode",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
"""Output schema for TextEncoderBlock."""
|
|
||||||
|
|
||||||
encoded_text: str = SchemaField(
|
|
||||||
description="The encoded text with special characters converted to escape sequences"
|
|
||||||
)
|
|
||||||
error: str = SchemaField(description="Error message if encoding fails")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
|
|
||||||
description="Encodes a string by converting special characters into escape sequences",
|
|
||||||
categories={BlockCategory.TEXT},
|
|
||||||
input_schema=TextEncoderBlock.Input,
|
|
||||||
output_schema=TextEncoderBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"text": """Hello
|
|
||||||
World!
|
|
||||||
This is a "quoted" string."""
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
(
|
|
||||||
"encoded_text",
|
|
||||||
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
|
||||||
"""
|
|
||||||
Encode the input text by converting special characters to escape sequences.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_data: The input containing the text to encode.
|
|
||||||
**kwargs: Additional keyword arguments (unused).
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
The encoded text with escape sequences, or an error message if encoding fails.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
|
|
||||||
"utf-8"
|
|
||||||
)
|
|
||||||
yield "encoded_text", encoded_text
|
|
||||||
except Exception as e:
|
|
||||||
yield "error", f"Encoding error: {str(e)}"
|
|
||||||
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
webset = await aexa.websets.get(id=input_data.external_id)
|
webset = aexa.websets.get(id=input_data.external_id)
|
||||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||||
|
|
||||||
yield "webset", webset_result
|
yield "webset", webset_result
|
||||||
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
|||||||
count=input_data.search_count,
|
count=input_data.search_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
webset = await aexa.websets.create(
|
webset = aexa.websets.create(
|
||||||
params=CreateWebsetParameters(
|
params=CreateWebsetParameters(
|
||||||
search=search_params,
|
search=search_params,
|
||||||
external_id=input_data.external_id,
|
external_id=input_data.external_id,
|
||||||
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
|
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.list(
|
response = aexa.websets.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
|
sdk_webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
|
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
deleted_webset.status.value
|
deleted_webset.status.value
|
||||||
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
|
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
canceled_webset.status.value
|
canceled_webset.status.value
|
||||||
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
|
|||||||
entity["description"] = input_data.entity_description
|
entity["description"] = input_data.entity_description
|
||||||
payload["entity"] = entity
|
payload["entity"] = entity
|
||||||
|
|
||||||
sdk_preview = await aexa.websets.preview(params=payload)
|
sdk_preview = aexa.websets.preview(params=payload)
|
||||||
|
|
||||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||||
|
|
||||||
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Extract basic info
|
# Extract basic info
|
||||||
webset_id = webset.id
|
webset_id = webset.id
|
||||||
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
total_items = 0
|
total_items = 0
|
||||||
|
|
||||||
if input_data.include_sample_items and input_data.sample_size > 0:
|
if input_data.include_sample_items and input_data.sample_size > 0:
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
sample_items_data = [
|
sample_items_data = [
|
||||||
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset details
|
# Get webset details
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = await aexa.websets.enrichments.create(
|
sdk_enrichment = aexa.websets.enrichments.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_enrich = await aexa.websets.enrichments.get(
|
current_enrich = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=enrichment_id
|
webset_id=input_data.webset_id, id=enrichment_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
|
|
||||||
if current_status in ["completed", "failed", "cancelled"]:
|
if current_status in ["completed", "failed", "cancelled"]:
|
||||||
# Estimate items from webset searches
|
# Estimate items from webset searches
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
for search in webset.searches:
|
for search in webset.searches:
|
||||||
if search.progress:
|
if search.progress:
|
||||||
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = await aexa.websets.enrichments.get(
|
sdk_enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_enrichment = await aexa.websets.enrichments.delete(
|
deleted_enrichment = aexa.websets.enrichments.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_enrichment = await aexa.websets.enrichments.cancel(
|
canceled_enrichment = aexa.websets.enrichments.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to estimate how many items were enriched before cancellation
|
# Try to estimate how many items were enriched before cancellation
|
||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=100
|
webset_id=input_data.webset_id, limit=100
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
# Create mock SDK import object
|
# Create mock SDK import object
|
||||||
mock_import = MagicMock()
|
mock_import = MagicMock()
|
||||||
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
imports=MagicMock(create=AsyncMock(return_value=mock_import))
|
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_import = await aexa.websets.imports.create(
|
sdk_import = aexa.websets.imports.create(
|
||||||
params=payload, csv_data=input_data.csv_data
|
params=payload, csv_data=input_data.csv_data
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
|
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
|
||||||
|
|
||||||
import_obj = ImportModel.from_sdk(sdk_import)
|
import_obj = ImportModel.from_sdk(sdk_import)
|
||||||
|
|
||||||
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.imports.list(
|
response = aexa.websets.imports.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -474,9 +474,7 @@ class ExaDeleteImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_import = await aexa.websets.imports.delete(
|
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
|
||||||
import_id=input_data.import_id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "import_id", deleted_import.id
|
yield "import_id", deleted_import.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -575,14 +573,14 @@ class ExaExportWebsetBlock(Block):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create async iterator for list_all
|
# Create mock iterator
|
||||||
async def async_item_iterator(*args, **kwargs):
|
mock_items = [mock_item1, mock_item2]
|
||||||
for item in [mock_item1, mock_item2]:
|
|
||||||
yield item
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
|
websets=MagicMock(
|
||||||
|
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -604,7 +602,7 @@ class ExaExportWebsetBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
async for sdk_item in item_iterator:
|
for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_item = await aexa.websets.items.get(
|
sdk_item = aexa.websets.items.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
response = None
|
response = None
|
||||||
|
|
||||||
while time.time() - start_time < input_data.wait_timeout:
|
while time.time() - start_time < input_data.wait_timeout:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
interval = min(interval * 1.2, 10)
|
interval = min(interval * 1.2, 10)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_item = await aexa.websets.items.delete(
|
deleted_item = aexa.websets.items.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
async for sdk_item in item_iterator:
|
for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
entity_type = "unknown"
|
entity_type = "unknown"
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Get sample items if requested
|
# Get sample items if requested
|
||||||
sample_items: List[WebsetItemModel] = []
|
sample_items: List[WebsetItemModel] = []
|
||||||
if input_data.sample_size > 0:
|
if input_data.sample_size > 0:
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
# Convert to our stable models
|
# Convert to our stable models
|
||||||
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get items starting from cursor
|
# Get items starting from cursor
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.since_cursor,
|
cursor=input_data.since_cursor,
|
||||||
limit=input_data.max_items,
|
limit=input_data.max_items,
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
# Create mock SDK monitor object
|
# Create mock SDK monitor object
|
||||||
mock_monitor = MagicMock()
|
mock_monitor = MagicMock()
|
||||||
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
|
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.create(params=payload)
|
sdk_monitor = aexa.websets.monitors.create(params=payload)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.update(
|
sdk_monitor = aexa.websets.monitors.update(
|
||||||
monitor_id=input_data.monitor_id, params=payload
|
monitor_id=input_data.monitor_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -522,9 +522,7 @@ class ExaDeleteMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_monitor = await aexa.websets.monitors.delete(
|
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
|
||||||
monitor_id=input_data.monitor_id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "monitor_id", deleted_monitor.id
|
yield "monitor_id", deleted_monitor.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -581,7 +579,7 @@ class ExaListMonitorsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.monitors.list(
|
response = aexa.websets.monitors.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
WebsetTargetStatus.IDLE,
|
WebsetTargetStatus.IDLE,
|
||||||
WebsetTargetStatus.ANY_COMPLETE,
|
WebsetTargetStatus.ANY_COMPLETE,
|
||||||
]:
|
]:
|
||||||
final_webset = await aexa.websets.wait_until_idle(
|
final_webset = aexa.websets.wait_until_idle(
|
||||||
id=input_data.webset_id,
|
id=input_data.webset_id,
|
||||||
timeout=input_data.timeout,
|
timeout=input_data.timeout,
|
||||||
poll_interval=input_data.check_interval,
|
poll_interval=input_data.check_interval,
|
||||||
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
interval = input_data.check_interval
|
interval = input_data.check_interval
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current webset status
|
# Get current webset status
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
current_status = (
|
current_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
|
|
||||||
# Timeout reached
|
# Timeout reached
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
final_status = (
|
final_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current search status using SDK
|
# Get current search status using SDK
|
||||||
search = await aexa.websets.searches.get(
|
search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
search = await aexa.websets.searches.get(
|
search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current enrichment status using SDK
|
# Get current enrichment status using SDK
|
||||||
enrichment = await aexa.websets.enrichments.get(
|
enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
enrichment = await aexa.websets.enrichments.get(
|
enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
) -> tuple[list[SampleEnrichmentModel], int]:
|
) -> tuple[list[SampleEnrichmentModel], int]:
|
||||||
"""Get sample enriched data and count."""
|
"""Get sample enriched data and count."""
|
||||||
# Get a few items to see enrichment results using SDK
|
# Get a few items to see enrichment results using SDK
|
||||||
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
|
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||||
|
|
||||||
sample_data: list[SampleEnrichmentModel] = []
|
sample_data: list[SampleEnrichmentModel] = []
|
||||||
enriched_count = 0
|
enriched_count = 0
|
||||||
|
|||||||
@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
|
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.create(
|
sdk_search = aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
poll_start = time.time()
|
poll_start = time.time()
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_search = await aexa.websets.searches.get(
|
current_search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=search_id
|
webset_id=input_data.webset_id, id=search_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.get(
|
sdk_search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_search = await aexa.websets.searches.cancel(
|
canceled_search = aexa.websets.searches.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset to check existing searches
|
# Get webset to check existing searches
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Look for existing search with same query
|
# Look for existing search with same query
|
||||||
existing_search = None
|
existing_search = None
|
||||||
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
if input_data.entity_type != SearchEntityType.AUTO:
|
if input_data.entity_type != SearchEntityType.AUTO:
|
||||||
payload["entity"] = {"type": input_data.entity_type.value}
|
payload["entity"] = {"type": input_data.entity_type.value}
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.create(
|
sdk_search = aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -115,7 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
||||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||||
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
|
||||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||||
# AI/ML API models
|
# AI/ML API models
|
||||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||||
@@ -271,9 +270,6 @@ MODEL_METADATA = {
|
|||||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||||
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
||||||
), # claude-4-sonnet-20250514
|
), # claude-4-sonnet-20250514
|
||||||
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
|
|
||||||
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
|
|
||||||
), # claude-opus-4-6
|
|
||||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||||
), # claude-opus-4-5-20251101
|
), # claude-opus-4-5-20251101
|
||||||
@@ -531,12 +527,12 @@ class LLMResponse(BaseModel):
|
|||||||
|
|
||||||
def convert_openai_tool_fmt_to_anthropic(
|
def convert_openai_tool_fmt_to_anthropic(
|
||||||
openai_tools: list[dict] | None = None,
|
openai_tools: list[dict] | None = None,
|
||||||
) -> Iterable[ToolParam] | anthropic.Omit:
|
) -> Iterable[ToolParam] | anthropic.NotGiven:
|
||||||
"""
|
"""
|
||||||
Convert OpenAI tool format to Anthropic tool format.
|
Convert OpenAI tool format to Anthropic tool format.
|
||||||
"""
|
"""
|
||||||
if not openai_tools or len(openai_tools) == 0:
|
if not openai_tools or len(openai_tools) == 0:
|
||||||
return anthropic.omit
|
return anthropic.NOT_GIVEN
|
||||||
|
|
||||||
anthropic_tools = []
|
anthropic_tools = []
|
||||||
for tool in openai_tools:
|
for tool in openai_tools:
|
||||||
@@ -596,10 +592,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
|||||||
|
|
||||||
def get_parallel_tool_calls_param(
|
def get_parallel_tool_calls_param(
|
||||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||||
) -> bool | openai.Omit:
|
):
|
||||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||||
return openai.omit
|
return openai.NOT_GIVEN
|
||||||
return parallel_tool_calls
|
return parallel_tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,301 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) Tool Block.
|
|
||||||
|
|
||||||
A single dynamic block that can connect to any MCP server, discover available tools,
|
|
||||||
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
|
|
||||||
dropdown and the input/output schema adapts dynamically.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockInput,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
BlockType,
|
|
||||||
)
|
|
||||||
from backend.data.model import (
|
|
||||||
CredentialsField,
|
|
||||||
CredentialsMetaInput,
|
|
||||||
OAuth2Credentials,
|
|
||||||
SchemaField,
|
|
||||||
)
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.json import validate_with_jsonschema
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
TEST_CREDENTIALS = OAuth2Credentials(
|
|
||||||
id="test-mcp-cred",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("mock-mcp-token"),
|
|
||||||
refresh_token=SecretStr("mock-refresh"),
|
|
||||||
scopes=[],
|
|
||||||
title="Mock MCP credential",
|
|
||||||
)
|
|
||||||
TEST_CREDENTIALS_INPUT = {
|
|
||||||
"provider": TEST_CREDENTIALS.provider,
|
|
||||||
"id": TEST_CREDENTIALS.id,
|
|
||||||
"type": TEST_CREDENTIALS.type,
|
|
||||||
"title": TEST_CREDENTIALS.title,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
|
|
||||||
|
|
||||||
|
|
||||||
class MCPToolBlock(Block):
|
|
||||||
"""
|
|
||||||
A block that connects to an MCP server, lets the user pick a tool,
|
|
||||||
and executes it with dynamic input/output schema.
|
|
||||||
|
|
||||||
The flow:
|
|
||||||
1. User provides an MCP server URL (and optional credentials)
|
|
||||||
2. Frontend calls the backend to get tool list from that URL
|
|
||||||
3. User selects a tool from a dropdown (available_tools)
|
|
||||||
4. The block's input schema updates to reflect the selected tool's parameters
|
|
||||||
5. On execution, the block calls the MCP server to run the tool
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
server_url: str = SchemaField(
|
|
||||||
description="URL of the MCP server (Streamable HTTP endpoint)",
|
|
||||||
placeholder="https://mcp.example.com/mcp",
|
|
||||||
)
|
|
||||||
credentials: MCPCredentials = CredentialsField(
|
|
||||||
discriminator="server_url",
|
|
||||||
description="MCP server OAuth credentials",
|
|
||||||
default={},
|
|
||||||
)
|
|
||||||
selected_tool: str = SchemaField(
|
|
||||||
description="The MCP tool to execute",
|
|
||||||
placeholder="Select a tool",
|
|
||||||
default="",
|
|
||||||
)
|
|
||||||
tool_input_schema: dict[str, Any] = SchemaField(
|
|
||||||
description="JSON Schema for the selected tool's input parameters. "
|
|
||||||
"Populated automatically when a tool is selected.",
|
|
||||||
default={},
|
|
||||||
hidden=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_arguments: dict[str, Any] = SchemaField(
|
|
||||||
description="Arguments to pass to the selected MCP tool. "
|
|
||||||
"The fields here are defined by the tool's input schema.",
|
|
||||||
default={},
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
|
||||||
"""Return the tool's input schema so the builder UI renders dynamic fields."""
|
|
||||||
return data.get("tool_input_schema", {})
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
|
||||||
"""Return the current tool_arguments as defaults for the dynamic fields."""
|
|
||||||
return data.get("tool_arguments", {})
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
|
||||||
"""Check which required tool arguments are missing."""
|
|
||||||
required_fields = cls.get_input_schema(data).get("required", [])
|
|
||||||
tool_arguments = data.get("tool_arguments", {})
|
|
||||||
return set(required_fields) - set(tool_arguments)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
|
||||||
"""Validate tool_arguments against the tool's input schema."""
|
|
||||||
tool_schema = cls.get_input_schema(data)
|
|
||||||
if not tool_schema:
|
|
||||||
return None
|
|
||||||
tool_arguments = data.get("tool_arguments", {})
|
|
||||||
return validate_with_jsonschema(tool_schema, tool_arguments)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
result: Any = SchemaField(description="The result returned by the MCP tool")
|
|
||||||
error: str = SchemaField(description="Error message if the tool call failed")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
|
||||||
description="Connect to any MCP server and execute its tools. "
|
|
||||||
"Provide a server URL, select a tool, and pass arguments dynamically.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=MCPToolBlock.Input,
|
|
||||||
output_schema=MCPToolBlock.Output,
|
|
||||||
block_type=BlockType.STANDARD,
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_input={
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
"selected_tool": "get_weather",
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
(
|
|
||||||
"result",
|
|
||||||
{"weather": "sunny", "temperature": 20},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"_call_mcp_tool": lambda *a, **kw: {
|
|
||||||
"weather": "sunny",
|
|
||||||
"temperature": 20,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _call_mcp_tool(
|
|
||||||
self,
|
|
||||||
server_url: str,
|
|
||||||
tool_name: str,
|
|
||||||
arguments: dict[str, Any],
|
|
||||||
auth_token: str | None = None,
|
|
||||||
) -> Any:
|
|
||||||
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
|
|
||||||
client = MCPClient(server_url, auth_token=auth_token)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool(tool_name, arguments)
|
|
||||||
|
|
||||||
if result.is_error:
|
|
||||||
error_text = ""
|
|
||||||
for item in result.content:
|
|
||||||
if item.get("type") == "text":
|
|
||||||
error_text += item.get("text", "")
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP tool '{tool_name}' returned an error: "
|
|
||||||
f"{error_text or 'Unknown error'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract text content from the result
|
|
||||||
output_parts = []
|
|
||||||
for item in result.content:
|
|
||||||
if item.get("type") == "text":
|
|
||||||
text = item.get("text", "")
|
|
||||||
# Try to parse as JSON for structured output
|
|
||||||
try:
|
|
||||||
output_parts.append(json.loads(text))
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
output_parts.append(text)
|
|
||||||
elif item.get("type") == "image":
|
|
||||||
output_parts.append(
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"data": item.get("data"),
|
|
||||||
"mimeType": item.get("mimeType"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif item.get("type") == "resource":
|
|
||||||
output_parts.append(item.get("resource", {}))
|
|
||||||
|
|
||||||
# If single result, unwrap
|
|
||||||
if len(output_parts) == 1:
|
|
||||||
return output_parts[0]
|
|
||||||
return output_parts if output_parts else None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _auto_lookup_credential(
|
|
||||||
user_id: str, server_url: str
|
|
||||||
) -> "OAuth2Credentials | None":
|
|
||||||
"""Auto-lookup stored MCP credential for a server URL.
|
|
||||||
|
|
||||||
This is a fallback for nodes that don't have ``credentials`` explicitly
|
|
||||||
set (e.g. nodes created before the credential field was wired up).
|
|
||||||
"""
|
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
try:
|
|
||||||
mgr = IntegrationCredentialsManager()
|
|
||||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
|
||||||
user_id, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
best: OAuth2Credentials | None = None
|
|
||||||
for cred in mcp_creds:
|
|
||||||
if (
|
|
||||||
isinstance(cred, OAuth2Credentials)
|
|
||||||
and cred.metadata.get("mcp_server_url") == server_url
|
|
||||||
):
|
|
||||||
if best is None or (
|
|
||||||
(cred.access_token_expires_at or 0)
|
|
||||||
> (best.access_token_expires_at or 0)
|
|
||||||
):
|
|
||||||
best = cred
|
|
||||||
if best:
|
|
||||||
best = await mgr.refresh_if_needed(user_id, best)
|
|
||||||
logger.info(
|
|
||||||
"Auto-resolved MCP credential %s for %s", best.id, server_url
|
|
||||||
)
|
|
||||||
return best
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Auto-lookup MCP credential failed", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
credentials: OAuth2Credentials | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
if not input_data.server_url:
|
|
||||||
yield "error", "MCP server URL is required"
|
|
||||||
return
|
|
||||||
|
|
||||||
if not input_data.selected_tool:
|
|
||||||
yield "error", "No tool selected. Please select a tool from the dropdown."
|
|
||||||
return
|
|
||||||
|
|
||||||
# Validate required tool arguments before calling the server.
|
|
||||||
# The executor-level validation is bypassed for MCP blocks because
|
|
||||||
# get_input_defaults() flattens tool_arguments, stripping tool_input_schema
|
|
||||||
# from the validation context.
|
|
||||||
required = set(input_data.tool_input_schema.get("required", []))
|
|
||||||
if required:
|
|
||||||
missing = required - set(input_data.tool_arguments.keys())
|
|
||||||
if missing:
|
|
||||||
yield "error", (
|
|
||||||
f"Missing required argument(s): {', '.join(sorted(missing))}. "
|
|
||||||
f"Please fill in all required fields marked with * in the block form."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# If no credentials were injected by the executor (e.g. legacy nodes
|
|
||||||
# that don't have the credentials field set), try to auto-lookup
|
|
||||||
# the stored MCP credential for this server URL.
|
|
||||||
if credentials is None:
|
|
||||||
credentials = await self._auto_lookup_credential(
|
|
||||||
user_id, input_data.server_url
|
|
||||||
)
|
|
||||||
|
|
||||||
auth_token = (
|
|
||||||
credentials.access_token.get_secret_value() if credentials else None
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await self._call_mcp_tool(
|
|
||||||
server_url=input_data.server_url,
|
|
||||||
tool_name=input_data.selected_tool,
|
|
||||||
arguments=input_data.tool_arguments,
|
|
||||||
auth_token=auth_token,
|
|
||||||
)
|
|
||||||
yield "result", result
|
|
||||||
except MCPClientError as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"MCP tool call failed: {e}")
|
|
||||||
yield "error", f"MCP tool call failed: {str(e)}"
|
|
||||||
@@ -1,318 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) HTTP client.
|
|
||||||
|
|
||||||
Implements the MCP Streamable HTTP transport for listing tools and calling tools
|
|
||||||
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
|
|
||||||
|
|
||||||
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
|
|
||||||
|
|
||||||
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MCPTool:
|
|
||||||
"""Represents an MCP tool discovered from a server."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
input_schema: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MCPCallResult:
|
|
||||||
"""Result from calling an MCP tool."""
|
|
||||||
|
|
||||||
content: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
is_error: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class MCPClientError(Exception):
|
|
||||||
"""Raised when an MCP protocol error occurs."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MCPClient:
|
|
||||||
"""
|
|
||||||
Async HTTP client for the MCP Streamable HTTP transport.
|
|
||||||
|
|
||||||
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
|
|
||||||
Supports optional Bearer token authentication.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
server_url: str,
|
|
||||||
auth_token: str | None = None,
|
|
||||||
):
|
|
||||||
self.server_url = server_url.rstrip("/")
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self._request_id = 0
|
|
||||||
self._session_id: str | None = None
|
|
||||||
|
|
||||||
def _next_id(self) -> int:
|
|
||||||
self._request_id += 1
|
|
||||||
return self._request_id
|
|
||||||
|
|
||||||
def _build_headers(self) -> dict[str, str]:
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Accept": "application/json, text/event-stream",
|
|
||||||
}
|
|
||||||
if self.auth_token:
|
|
||||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
|
||||||
if self._session_id:
|
|
||||||
headers["Mcp-Session-Id"] = self._session_id
|
|
||||||
return headers
|
|
||||||
|
|
||||||
def _build_jsonrpc_request(
|
|
||||||
self, method: str, params: dict[str, Any] | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
req: dict[str, Any] = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"method": method,
|
|
||||||
"id": self._next_id(),
|
|
||||||
}
|
|
||||||
if params is not None:
|
|
||||||
req["params"] = params
|
|
||||||
return req
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_sse_response(text: str) -> dict[str, Any]:
|
|
||||||
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
|
|
||||||
|
|
||||||
MCP servers may return responses as SSE with format:
|
|
||||||
event: message
|
|
||||||
data: {"jsonrpc":"2.0","result":{...},"id":1}
|
|
||||||
|
|
||||||
We extract the last `data:` line that contains a JSON-RPC response
|
|
||||||
(i.e. has an "id" field), which is the reply to our request.
|
|
||||||
"""
|
|
||||||
last_data: dict[str, Any] | None = None
|
|
||||||
for line in text.splitlines():
|
|
||||||
stripped = line.strip()
|
|
||||||
if stripped.startswith("data:"):
|
|
||||||
payload = stripped[len("data:") :].strip()
|
|
||||||
if not payload:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
parsed = json.loads(payload)
|
|
||||||
# Only keep JSON-RPC responses (have "id"), skip notifications
|
|
||||||
if isinstance(parsed, dict) and "id" in parsed:
|
|
||||||
last_data = parsed
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
continue
|
|
||||||
if last_data is None:
|
|
||||||
raise MCPClientError("No JSON-RPC response found in SSE stream")
|
|
||||||
return last_data
|
|
||||||
|
|
||||||
async def _send_request(
|
|
||||||
self, method: str, params: dict[str, Any] | None = None
|
|
||||||
) -> Any:
|
|
||||||
"""Send a JSON-RPC request to the MCP server and return the result.
|
|
||||||
|
|
||||||
Handles both ``application/json`` and ``text/event-stream`` responses
|
|
||||||
as required by the MCP Streamable HTTP transport specification.
|
|
||||||
"""
|
|
||||||
payload = self._build_jsonrpc_request(method, params)
|
|
||||||
headers = self._build_headers()
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=True,
|
|
||||||
extra_headers=headers,
|
|
||||||
)
|
|
||||||
response = await requests.post(self.server_url, json=payload)
|
|
||||||
|
|
||||||
# Capture session ID from response (MCP Streamable HTTP transport)
|
|
||||||
session_id = response.headers.get("Mcp-Session-Id")
|
|
||||||
if session_id:
|
|
||||||
self._session_id = session_id
|
|
||||||
|
|
||||||
content_type = response.headers.get("content-type", "")
|
|
||||||
if "text/event-stream" in content_type:
|
|
||||||
body = self._parse_sse_response(response.text())
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
body = response.json()
|
|
||||||
except Exception as e:
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP server returned non-JSON response: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
# Handle JSON-RPC error
|
|
||||||
if "error" in body:
|
|
||||||
error = body["error"]
|
|
||||||
if isinstance(error, dict):
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP server error [{error.get('code', '?')}]: "
|
|
||||||
f"{error.get('message', 'Unknown error')}"
|
|
||||||
)
|
|
||||||
raise MCPClientError(f"MCP server error: {error}")
|
|
||||||
|
|
||||||
return body.get("result")
|
|
||||||
|
|
||||||
async def _send_notification(self, method: str) -> None:
|
|
||||||
"""Send a JSON-RPC notification (no id, no response expected)."""
|
|
||||||
headers = self._build_headers()
|
|
||||||
notification = {"jsonrpc": "2.0", "method": method}
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
extra_headers=headers,
|
|
||||||
)
|
|
||||||
await requests.post(self.server_url, json=notification)
|
|
||||||
|
|
||||||
async def discover_auth(self) -> dict[str, Any] | None:
|
|
||||||
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
|
|
||||||
|
|
||||||
Returns ``None`` if the server doesn't require auth, otherwise returns
|
|
||||||
a dict with:
|
|
||||||
- ``authorization_servers``: list of authorization server URLs
|
|
||||||
- ``resource``: the resource indicator URL (usually the MCP endpoint)
|
|
||||||
- ``scopes_supported``: optional list of supported scopes
|
|
||||||
|
|
||||||
The caller can then fetch the authorization server metadata to get
|
|
||||||
``authorization_endpoint``, ``token_endpoint``, etc.
|
|
||||||
"""
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
parsed = urlparse(self.server_url)
|
|
||||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
|
||||||
|
|
||||||
# Build candidates for protected-resource metadata (per RFC 9728)
|
|
||||||
path = parsed.path.rstrip("/")
|
|
||||||
candidates = []
|
|
||||||
if path and path != "/":
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-protected-resource")
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
)
|
|
||||||
for url in candidates:
|
|
||||||
try:
|
|
||||||
resp = await requests.get(url)
|
|
||||||
if resp.status == 200:
|
|
||||||
data = resp.json()
|
|
||||||
if isinstance(data, dict) and "authorization_servers" in data:
|
|
||||||
return data
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def discover_auth_server_metadata(
|
|
||||||
self, auth_server_url: str
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
|
|
||||||
|
|
||||||
Given an authorization server URL, returns a dict with:
|
|
||||||
- ``authorization_endpoint``
|
|
||||||
- ``token_endpoint``
|
|
||||||
- ``registration_endpoint`` (for dynamic client registration)
|
|
||||||
- ``scopes_supported``
|
|
||||||
- ``code_challenge_methods_supported``
|
|
||||||
- etc.
|
|
||||||
"""
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
parsed = urlparse(auth_server_url)
|
|
||||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
|
||||||
path = parsed.path.rstrip("/")
|
|
||||||
|
|
||||||
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
|
|
||||||
candidates = []
|
|
||||||
if path and path != "/":
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-authorization-server")
|
|
||||||
candidates.append(f"{base}/.well-known/openid-configuration")
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
)
|
|
||||||
for url in candidates:
|
|
||||||
try:
|
|
||||||
resp = await requests.get(url)
|
|
||||||
if resp.status == 200:
|
|
||||||
data = resp.json()
|
|
||||||
if isinstance(data, dict) and "authorization_endpoint" in data:
|
|
||||||
return data
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def initialize(self) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Send the MCP initialize request.
|
|
||||||
|
|
||||||
This is required by the MCP protocol before any other requests.
|
|
||||||
Returns the server's capabilities.
|
|
||||||
"""
|
|
||||||
result = await self._send_request(
|
|
||||||
"initialize",
|
|
||||||
{
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {},
|
|
||||||
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Send initialized notification (no response expected)
|
|
||||||
await self._send_notification("notifications/initialized")
|
|
||||||
|
|
||||||
return result or {}
|
|
||||||
|
|
||||||
async def list_tools(self) -> list[MCPTool]:
|
|
||||||
"""
|
|
||||||
Discover available tools from the MCP server.
|
|
||||||
|
|
||||||
Returns a list of MCPTool objects with name, description, and input schema.
|
|
||||||
"""
|
|
||||||
result = await self._send_request("tools/list")
|
|
||||||
if not result or "tools" not in result:
|
|
||||||
return []
|
|
||||||
|
|
||||||
tools = []
|
|
||||||
for tool_data in result["tools"]:
|
|
||||||
tools.append(
|
|
||||||
MCPTool(
|
|
||||||
name=tool_data.get("name", ""),
|
|
||||||
description=tool_data.get("description", ""),
|
|
||||||
input_schema=tool_data.get("inputSchema", {}),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return tools
|
|
||||||
|
|
||||||
async def call_tool(
|
|
||||||
self, tool_name: str, arguments: dict[str, Any]
|
|
||||||
) -> MCPCallResult:
|
|
||||||
"""
|
|
||||||
Call a tool on the MCP server.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: The name of the tool to call.
|
|
||||||
arguments: The arguments to pass to the tool.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MCPCallResult with the tool's response content.
|
|
||||||
"""
|
|
||||||
result = await self._send_request(
|
|
||||||
"tools/call",
|
|
||||||
{"name": tool_name, "arguments": arguments},
|
|
||||||
)
|
|
||||||
if not result:
|
|
||||||
return MCPCallResult(is_error=True)
|
|
||||||
|
|
||||||
return MCPCallResult(
|
|
||||||
content=result.get("content", []),
|
|
||||||
is_error=result.get("isError", False),
|
|
||||||
)
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
"""
|
|
||||||
Conftest for MCP block tests.
|
|
||||||
|
|
||||||
Override the session-scoped server and graph_cleanup fixtures from
|
|
||||||
backend/conftest.py so that MCP integration tests don't spin up the
|
|
||||||
full SpinTestServer infrastructure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config: pytest.Config) -> None:
|
|
||||||
config.addinivalue_line("markers", "e2e: end-to-end tests requiring network")
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_collection_modifyitems(
|
|
||||||
config: pytest.Config, items: list[pytest.Item]
|
|
||||||
) -> None:
|
|
||||||
"""Skip e2e tests unless --run-e2e is passed."""
|
|
||||||
if not config.getoption("--run-e2e", default=False):
|
|
||||||
skip_e2e = pytest.mark.skip(reason="need --run-e2e option to run")
|
|
||||||
for item in items:
|
|
||||||
if "e2e" in item.keywords:
|
|
||||||
item.add_marker(skip_e2e)
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
|
||||||
parser.addoption(
|
|
||||||
"--run-e2e", action="store_true", default=False, help="run e2e tests"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def server():
|
|
||||||
"""No-op override — MCP tests don't need the full platform server."""
|
|
||||||
yield None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def graph_cleanup(server):
|
|
||||||
"""No-op override — MCP tests don't create graphs."""
|
|
||||||
yield
|
|
||||||
@@ -1,198 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
|
|
||||||
|
|
||||||
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
|
|
||||||
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
|
|
||||||
This handler accepts those endpoints at construction time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import urllib.parse
|
|
||||||
from typing import ClassVar, Optional
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthHandler(BaseOAuthHandler):
|
|
||||||
"""
|
|
||||||
OAuth handler for MCP servers with dynamically-discovered endpoints.
|
|
||||||
|
|
||||||
Construction requires the authorization and token endpoint URLs,
|
|
||||||
which are obtained via MCP OAuth metadata discovery
|
|
||||||
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
|
|
||||||
DEFAULT_SCOPES: ClassVar[list[str]] = []
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
client_id: str,
|
|
||||||
client_secret: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
*,
|
|
||||||
authorize_url: str,
|
|
||||||
token_url: str,
|
|
||||||
revoke_url: str | None = None,
|
|
||||||
resource_url: str | None = None,
|
|
||||||
):
|
|
||||||
self.client_id = client_id
|
|
||||||
self.client_secret = client_secret
|
|
||||||
self.redirect_uri = redirect_uri
|
|
||||||
self.authorize_url = authorize_url
|
|
||||||
self.token_url = token_url
|
|
||||||
self.revoke_url = revoke_url
|
|
||||||
self.resource_url = resource_url
|
|
||||||
|
|
||||||
def get_login_url(
|
|
||||||
self,
|
|
||||||
scopes: list[str],
|
|
||||||
state: str,
|
|
||||||
code_challenge: Optional[str],
|
|
||||||
) -> str:
|
|
||||||
scopes = self.handle_default_scopes(scopes)
|
|
||||||
|
|
||||||
params: dict[str, str] = {
|
|
||||||
"response_type": "code",
|
|
||||||
"client_id": self.client_id,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"state": state,
|
|
||||||
}
|
|
||||||
if scopes:
|
|
||||||
params["scope"] = " ".join(scopes)
|
|
||||||
# PKCE (S256) — included when the caller provides a code_challenge
|
|
||||||
if code_challenge:
|
|
||||||
params["code_challenge"] = code_challenge
|
|
||||||
params["code_challenge_method"] = "S256"
|
|
||||||
# MCP spec requires resource indicator (RFC 8707)
|
|
||||||
if self.resource_url:
|
|
||||||
params["resource"] = self.resource_url
|
|
||||||
|
|
||||||
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
|
|
||||||
|
|
||||||
async def exchange_code_for_tokens(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
scopes: list[str],
|
|
||||||
code_verifier: Optional[str],
|
|
||||||
) -> OAuth2Credentials:
|
|
||||||
data: dict[str, str] = {
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"code": code,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
if self.client_secret:
|
|
||||||
data["client_secret"] = self.client_secret
|
|
||||||
if code_verifier:
|
|
||||||
data["code_verifier"] = code_verifier
|
|
||||||
if self.resource_url:
|
|
||||||
data["resource"] = self.resource_url
|
|
||||||
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
self.token_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
tokens = response.json()
|
|
||||||
|
|
||||||
if "error" in tokens:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
expires_in = tokens.get("expires_in")
|
|
||||||
|
|
||||||
return OAuth2Credentials(
|
|
||||||
provider=self.PROVIDER_NAME,
|
|
||||||
title=None,
|
|
||||||
access_token=SecretStr(tokens["access_token"]),
|
|
||||||
refresh_token=(
|
|
||||||
SecretStr(tokens["refresh_token"])
|
|
||||||
if tokens.get("refresh_token")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
access_token_expires_at=now + expires_in if expires_in else None,
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=scopes,
|
|
||||||
metadata={
|
|
||||||
"mcp_token_url": self.token_url,
|
|
||||||
"mcp_resource_url": self.resource_url,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _refresh_tokens(
|
|
||||||
self, credentials: OAuth2Credentials
|
|
||||||
) -> OAuth2Credentials:
|
|
||||||
if not credentials.refresh_token:
|
|
||||||
raise ValueError("No refresh token available for MCP OAuth credentials")
|
|
||||||
|
|
||||||
data: dict[str, str] = {
|
|
||||||
"grant_type": "refresh_token",
|
|
||||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
if self.client_secret:
|
|
||||||
data["client_secret"] = self.client_secret
|
|
||||||
if self.resource_url:
|
|
||||||
data["resource"] = self.resource_url
|
|
||||||
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
self.token_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
tokens = response.json()
|
|
||||||
|
|
||||||
if "error" in tokens:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
expires_in = tokens.get("expires_in")
|
|
||||||
|
|
||||||
return OAuth2Credentials(
|
|
||||||
id=credentials.id,
|
|
||||||
provider=self.PROVIDER_NAME,
|
|
||||||
title=credentials.title,
|
|
||||||
access_token=SecretStr(tokens["access_token"]),
|
|
||||||
refresh_token=(
|
|
||||||
SecretStr(tokens["refresh_token"])
|
|
||||||
if tokens.get("refresh_token")
|
|
||||||
else credentials.refresh_token
|
|
||||||
),
|
|
||||||
access_token_expires_at=now + expires_in if expires_in else None,
|
|
||||||
refresh_token_expires_at=credentials.refresh_token_expires_at,
|
|
||||||
scopes=credentials.scopes,
|
|
||||||
metadata=credentials.metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
|
||||||
if not self.revoke_url:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = {
|
|
||||||
"token": credentials.access_token.get_secret_value(),
|
|
||||||
"token_type_hint": "access_token",
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
await Requests().post(
|
|
||||||
self.revoke_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
|
|
||||||
return False
|
|
||||||
@@ -1,104 +0,0 @@
|
|||||||
"""
|
|
||||||
End-to-end tests against a real public MCP server.
|
|
||||||
|
|
||||||
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
|
|
||||||
which is publicly accessible without authentication and returns SSE responses.
|
|
||||||
|
|
||||||
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
|
|
||||||
independently of the rest of the test suite (they require network access).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
|
|
||||||
# Public MCP server that requires no authentication
|
|
||||||
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.e2e
|
|
||||||
class TestRealMCPServer:
|
|
||||||
"""Tests against the live OpenAI docs MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize(self):
|
|
||||||
"""Verify we can complete the MCP handshake with a real server."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
assert "serverInfo" in result
|
|
||||||
assert result["serverInfo"]["name"] == "openai-docs-mcp"
|
|
||||||
assert "tools" in result.get("capabilities", {})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools(self):
|
|
||||||
"""Verify we can discover tools from a real MCP server."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) >= 3 # server has at least 5 tools as of writing
|
|
||||||
|
|
||||||
tool_names = {t.name for t in tools}
|
|
||||||
# These tools are documented and should be stable
|
|
||||||
assert "search_openai_docs" in tool_names
|
|
||||||
assert "list_openai_docs" in tool_names
|
|
||||||
assert "fetch_openai_doc" in tool_names
|
|
||||||
|
|
||||||
# Verify schema structure
|
|
||||||
search_tool = next(t for t in tools if t.name == "search_openai_docs")
|
|
||||||
assert "query" in search_tool.input_schema.get("properties", {})
|
|
||||||
assert "query" in search_tool.input_schema.get("required", [])
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_list_api_endpoints(self):
|
|
||||||
"""Call the list_api_endpoints tool and verify we get real data."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("list_api_endpoints", {})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) >= 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert "paths" in data or "urls" in data
|
|
||||||
# The OpenAI API should have many endpoints
|
|
||||||
total = data.get("total", len(data.get("paths", [])))
|
|
||||||
assert total > 50
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_search(self):
|
|
||||||
"""Search for docs and verify we get results."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool(
|
|
||||||
"search_openai_docs", {"query": "chat completions", "limit": 3}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) >= 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sse_response_handling(self):
|
|
||||||
"""Verify the client correctly handles SSE responses from a real server.
|
|
||||||
|
|
||||||
This is the key test — our local test server returns JSON,
|
|
||||||
but real MCP servers typically return SSE. This proves the
|
|
||||||
SSE parsing works end-to-end.
|
|
||||||
"""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
# initialize() internally calls _send_request which must parse SSE
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
# If we got here without error, SSE parsing works
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert "protocolVersion" in result
|
|
||||||
|
|
||||||
# Also verify list_tools works (another SSE response)
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) > 0
|
|
||||||
assert all(hasattr(t, "name") for t in tools)
|
|
||||||
@@ -1,389 +0,0 @@
|
|||||||
"""
|
|
||||||
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
|
|
||||||
|
|
||||||
These tests spin up a local MCP test server and run the full client/block flow
|
|
||||||
against it — no mocking, real HTTP requests.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import threading
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from aiohttp import web
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks.mcp.block import MCPToolBlock
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
from backend.blocks.mcp.test_server import create_test_mcp_app
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
MOCK_USER_ID = "test-user-integration"
|
|
||||||
|
|
||||||
|
|
||||||
class _MCPTestServer:
|
|
||||||
"""
|
|
||||||
Run an MCP test server in a background thread with its own event loop.
|
|
||||||
This avoids event loop conflicts with pytest-asyncio.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, auth_token: str | None = None):
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.url: str = ""
|
|
||||||
self._runner: web.AppRunner | None = None
|
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
self._started = threading.Event()
|
|
||||||
|
|
||||||
def _run(self):
|
|
||||||
self._loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(self._loop)
|
|
||||||
self._loop.run_until_complete(self._start())
|
|
||||||
self._started.set()
|
|
||||||
self._loop.run_forever()
|
|
||||||
|
|
||||||
async def _start(self):
|
|
||||||
app = create_test_mcp_app(auth_token=self.auth_token)
|
|
||||||
self._runner = web.AppRunner(app)
|
|
||||||
await self._runner.setup()
|
|
||||||
site = web.TCPSite(self._runner, "127.0.0.1", 0)
|
|
||||||
await site.start()
|
|
||||||
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
|
|
||||||
self.url = f"http://127.0.0.1:{port}/mcp"
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
if not self._started.wait(timeout=5):
|
|
||||||
raise RuntimeError("MCP test server failed to start within 5 seconds")
|
|
||||||
return self
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
if self._loop and self._runner:
|
|
||||||
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
|
|
||||||
timeout=5
|
|
||||||
)
|
|
||||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
|
||||||
if self._thread:
|
|
||||||
self._thread.join(timeout=5)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def mcp_server():
|
|
||||||
"""Start a local MCP test server in a background thread."""
|
|
||||||
server = _MCPTestServer()
|
|
||||||
server.start()
|
|
||||||
yield server.url
|
|
||||||
server.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def mcp_server_with_auth():
|
|
||||||
"""Start a local MCP test server with auth in a background thread."""
|
|
||||||
server = _MCPTestServer(auth_token="test-secret-token")
|
|
||||||
server.start()
|
|
||||||
yield server.url, "test-secret-token"
|
|
||||||
server.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _allow_localhost():
|
|
||||||
"""
|
|
||||||
Allow 127.0.0.1 through SSRF protection for integration tests.
|
|
||||||
|
|
||||||
The Requests class blocks private IPs by default. We patch the Requests
|
|
||||||
constructor to always include 127.0.0.1 as a trusted origin so the local
|
|
||||||
test server is reachable.
|
|
||||||
"""
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
original_init = Requests.__init__
|
|
||||||
|
|
||||||
def patched_init(self, *args, **kwargs):
|
|
||||||
trusted = list(kwargs.get("trusted_origins") or [])
|
|
||||||
trusted.append("http://127.0.0.1")
|
|
||||||
kwargs["trusted_origins"] = trusted
|
|
||||||
original_init(self, *args, **kwargs)
|
|
||||||
|
|
||||||
with patch.object(Requests, "__init__", patched_init):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
|
|
||||||
"""Create an MCPClient for integration tests."""
|
|
||||||
return MCPClient(url, auth_token=auth_token)
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPClient integration tests ──────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientIntegration:
|
|
||||||
"""Test MCPClient against a real local MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
assert result["serverInfo"]["name"] == "test-mcp-server"
|
|
||||||
assert "tools" in result["capabilities"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
tool_names = {t.name for t in tools}
|
|
||||||
assert tool_names == {"get_weather", "add_numbers", "echo"}
|
|
||||||
|
|
||||||
# Check get_weather schema
|
|
||||||
weather = next(t for t in tools if t.name == "get_weather")
|
|
||||||
assert weather.description == "Get current weather for a city"
|
|
||||||
assert "city" in weather.input_schema["properties"]
|
|
||||||
assert weather.input_schema["required"] == ["city"]
|
|
||||||
|
|
||||||
# Check add_numbers schema
|
|
||||||
add = next(t for t in tools if t.name == "add_numbers")
|
|
||||||
assert "a" in add.input_schema["properties"]
|
|
||||||
assert "b" in add.input_schema["properties"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_get_weather(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) == 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert data["city"] == "London"
|
|
||||||
assert data["temperature"] == 22
|
|
||||||
assert data["condition"] == "sunny"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_add_numbers(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert data["result"] == 10
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_echo(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("echo", {"message": "Hello MCP!"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert result.content[0]["text"] == "Hello MCP!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_unknown_tool(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("nonexistent_tool", {})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
assert "Unknown tool" in result.content[0]["text"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auth_success(self, mcp_server_with_auth):
|
|
||||||
url, token = mcp_server_with_auth
|
|
||||||
client = _make_client(url, auth_token=token)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auth_failure(self, mcp_server_with_auth):
|
|
||||||
url, _ = mcp_server_with_auth
|
|
||||||
client = _make_client(url, auth_token="wrong-token")
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await client.initialize()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auth_missing(self, mcp_server_with_auth):
|
|
||||||
url, _ = mcp_server_with_auth
|
|
||||||
client = _make_client(url)
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await client.initialize()
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPToolBlock integration tests ───────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPToolBlockIntegration:
|
|
||||||
"""Test MCPToolBlock end-to-end against a real local MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_flow_get_weather(self, mcp_server):
|
|
||||||
"""Full flow: discover tools, select one, execute it."""
|
|
||||||
# Step 1: Discover tools (simulating what the frontend/API would do)
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
# Step 2: User selects "get_weather" and we get its schema
|
|
||||||
weather_tool = next(t for t in tools if t.name == "get_weather")
|
|
||||||
|
|
||||||
# Step 3: Execute the block — no credentials (public server)
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="get_weather",
|
|
||||||
tool_input_schema=weather_tool.input_schema,
|
|
||||||
tool_arguments={"city": "Paris"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
result = outputs[0][1]
|
|
||||||
assert result["city"] == "Paris"
|
|
||||||
assert result["temperature"] == 22
|
|
||||||
assert result["condition"] == "sunny"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_flow_add_numbers(self, mcp_server):
|
|
||||||
"""Full flow for add_numbers tool."""
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
add_tool = next(t for t in tools if t.name == "add_numbers")
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="add_numbers",
|
|
||||||
tool_input_schema=add_tool.input_schema,
|
|
||||||
tool_arguments={"a": 42, "b": 58},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1]["result"] == 100
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_flow_echo_plain_text(self, mcp_server):
|
|
||||||
"""Verify plain text (non-JSON) responses work."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "Hello from AutoGPT!"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "Hello from AutoGPT!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
|
|
||||||
"""Calling an unknown tool should yield an error output."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="nonexistent_tool",
|
|
||||||
tool_arguments={},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "error"
|
|
||||||
assert "returned an error" in outputs[0][1]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_flow_with_auth(self, mcp_server_with_auth):
|
|
||||||
"""Full flow with authentication via credentials kwarg."""
|
|
||||||
url, token = mcp_server_with_auth
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=url,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "Authenticated!"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pass credentials via the standard kwarg (as the executor would)
|
|
||||||
test_creds = OAuth2Credentials(
|
|
||||||
id="test-cred",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr(token),
|
|
||||||
refresh_token=SecretStr(""),
|
|
||||||
scopes=[],
|
|
||||||
title="Test MCP credential",
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
|
||||||
):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "Authenticated!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_credentials_runs_without_auth(self, mcp_server):
|
|
||||||
"""Block runs without auth when no credentials are provided."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "No auth needed"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=None
|
|
||||||
):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "No auth needed"
|
|
||||||
@@ -1,619 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for MCP client and MCPToolBlock.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.blocks.mcp.block import MCPToolBlock
|
|
||||||
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
|
||||||
from backend.util.test import execute_block_test
|
|
||||||
|
|
||||||
# ── SSE parsing unit tests ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestSSEParsing:
|
|
||||||
"""Tests for SSE (text/event-stream) response parsing."""
|
|
||||||
|
|
||||||
def test_parse_sse_simple(self):
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == {"tools": []}
|
|
||||||
assert body["id"] == 1
|
|
||||||
|
|
||||||
def test_parse_sse_with_notifications(self):
|
|
||||||
"""SSE streams can contain notifications (no id) before the response."""
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
|
|
||||||
"\n"
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == {"ok": True}
|
|
||||||
assert body["id"] == 2
|
|
||||||
|
|
||||||
def test_parse_sse_error_response(self):
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert "error" in body
|
|
||||||
assert body["error"]["code"] == -32600
|
|
||||||
|
|
||||||
def test_parse_sse_no_data_raises(self):
|
|
||||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
|
||||||
MCPClient._parse_sse_response("event: message\n\n")
|
|
||||||
|
|
||||||
def test_parse_sse_empty_raises(self):
|
|
||||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
|
||||||
MCPClient._parse_sse_response("")
|
|
||||||
|
|
||||||
def test_parse_sse_ignores_non_data_lines(self):
|
|
||||||
sse = (
|
|
||||||
": comment line\n"
|
|
||||||
"event: message\n"
|
|
||||||
"id: 123\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == "ok"
|
|
||||||
|
|
||||||
def test_parse_sse_uses_last_response(self):
|
|
||||||
"""If multiple responses exist, use the last one."""
|
|
||||||
sse = (
|
|
||||||
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
|
|
||||||
"\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == "second"
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPClient unit tests ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClient:
|
|
||||||
"""Tests for the MCP HTTP client."""
|
|
||||||
|
|
||||||
def test_build_headers_without_auth(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
headers = client._build_headers()
|
|
||||||
assert "Authorization" not in headers
|
|
||||||
assert headers["Content-Type"] == "application/json"
|
|
||||||
|
|
||||||
def test_build_headers_with_auth(self):
|
|
||||||
client = MCPClient("https://mcp.example.com", auth_token="my-token")
|
|
||||||
headers = client._build_headers()
|
|
||||||
assert headers["Authorization"] == "Bearer my-token"
|
|
||||||
|
|
||||||
def test_build_jsonrpc_request(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req = client._build_jsonrpc_request("tools/list")
|
|
||||||
assert req["jsonrpc"] == "2.0"
|
|
||||||
assert req["method"] == "tools/list"
|
|
||||||
assert "id" in req
|
|
||||||
assert "params" not in req
|
|
||||||
|
|
||||||
def test_build_jsonrpc_request_with_params(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req = client._build_jsonrpc_request(
|
|
||||||
"tools/call", {"name": "test", "arguments": {"x": 1}}
|
|
||||||
)
|
|
||||||
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
|
|
||||||
|
|
||||||
def test_request_id_increments(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req1 = client._build_jsonrpc_request("tools/list")
|
|
||||||
req2 = client._build_jsonrpc_request("tools/list")
|
|
||||||
assert req2["id"] > req1["id"]
|
|
||||||
|
|
||||||
def test_server_url_trailing_slash_stripped(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp/")
|
|
||||||
assert client.server_url == "https://mcp.example.com/mcp"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_request_success(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"result": {"tools": []},
|
|
||||||
"id": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
|
||||||
result = await client._send_request("tools/list")
|
|
||||||
assert result == {"tools": []}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_request_error(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
async def mock_send(*args, **kwargs):
|
|
||||||
raise MCPClientError("MCP server error [-32600]: Invalid Request")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", side_effect=mock_send):
|
|
||||||
with pytest.raises(MCPClientError, match="Invalid Request"):
|
|
||||||
await client._send_request("tools/list")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"tools": [
|
|
||||||
{
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get current weather for a city",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "search",
|
|
||||||
"description": "Search the web",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"query": {"type": "string"}},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) == 2
|
|
||||||
assert tools[0].name == "get_weather"
|
|
||||||
assert tools[0].description == "Get current weather for a city"
|
|
||||||
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
|
|
||||||
assert tools[1].name == "search"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools_empty(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert tools == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools_none_result(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=None):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert tools == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_success(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
|
|
||||||
],
|
|
||||||
"isError": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) == 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_error(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"content": [{"type": "text", "text": "City not found"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "???"})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_none_result(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=None):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {"tools": {}},
|
|
||||||
"serverInfo": {"name": "test-server", "version": "1.0.0"},
|
|
||||||
}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
|
|
||||||
patch.object(client, "_send_notification") as mock_notif,
|
|
||||||
):
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
mock_req.assert_called_once()
|
|
||||||
mock_notif.assert_called_once_with("notifications/initialized")
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPToolBlock unit tests ──────────────────────────────────────────
|
|
||||||
|
|
||||||
MOCK_USER_ID = "test-user-123"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPToolBlock:
|
|
||||||
"""Tests for the MCPToolBlock."""
|
|
||||||
|
|
||||||
def test_block_instantiation(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
|
||||||
assert block.name == "MCPToolBlock"
|
|
||||||
|
|
||||||
def test_input_schema_has_required_fields(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
schema = block.input_schema.jsonschema()
|
|
||||||
props = schema.get("properties", {})
|
|
||||||
assert "server_url" in props
|
|
||||||
assert "selected_tool" in props
|
|
||||||
assert "tool_arguments" in props
|
|
||||||
assert "credentials" in props
|
|
||||||
|
|
||||||
def test_output_schema(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
schema = block.output_schema.jsonschema()
|
|
||||||
props = schema.get("properties", {})
|
|
||||||
assert "result" in props
|
|
||||||
assert "error" in props
|
|
||||||
|
|
||||||
def test_get_input_schema_with_tool_schema(self):
|
|
||||||
tool_schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"query": {"type": "string"}},
|
|
||||||
"required": ["query"],
|
|
||||||
}
|
|
||||||
data = {"tool_input_schema": tool_schema}
|
|
||||||
result = MCPToolBlock.Input.get_input_schema(data)
|
|
||||||
assert result == tool_schema
|
|
||||||
|
|
||||||
def test_get_input_schema_without_tool_schema(self):
|
|
||||||
result = MCPToolBlock.Input.get_input_schema({})
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
def test_get_input_defaults(self):
|
|
||||||
data = {"tool_arguments": {"city": "London"}}
|
|
||||||
result = MCPToolBlock.Input.get_input_defaults(data)
|
|
||||||
assert result == {"city": "London"}
|
|
||||||
|
|
||||||
def test_get_missing_input(self):
|
|
||||||
data = {
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {"type": "string"},
|
|
||||||
"units": {"type": "string"},
|
|
||||||
},
|
|
||||||
"required": ["city", "units"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
}
|
|
||||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
|
||||||
assert missing == {"units"}
|
|
||||||
|
|
||||||
def test_get_missing_input_all_present(self):
|
|
||||||
data = {
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
}
|
|
||||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
|
||||||
assert missing == set()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_with_mock(self):
|
|
||||||
"""Test the block using the built-in test infrastructure."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
await execute_block_test(block)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_missing_server_url(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="",
|
|
||||||
selected_tool="test",
|
|
||||||
)
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
assert outputs == [("error", "MCP server URL is required")]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_missing_tool(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="",
|
|
||||||
)
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
assert outputs == [
|
|
||||||
("error", "No tool selected. Please select a tool from the dropdown.")
|
|
||||||
]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_success(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="get_weather",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
},
|
|
||||||
tool_arguments={"city": "London"},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_call(*args, **kwargs):
|
|
||||||
return {"temp": 20, "city": "London"}
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == {"temp": 20, "city": "London"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_mcp_error(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="bad_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_call(*args, **kwargs):
|
|
||||||
raise MCPClientError("Tool not found")
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert outputs[0][0] == "error"
|
|
||||||
assert "Tool not found" in outputs[0][1]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_mcp_tool_parses_json_text(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": '{"temp": 20}'},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {"temp": 20}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_mcp_tool_plain_text(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "Hello, world!"},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == "Hello, world!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_mcp_tool_multiple_content(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "Part 1"},
|
|
||||||
{"type": "text", "text": '{"part": 2}'},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == ["Part 1", {"part": 2}]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_mcp_tool_error_result(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[{"type": "text", "text": "Something went wrong"}],
|
|
||||||
is_error=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
with pytest.raises(MCPClientError, match="returned an error"):
|
|
||||||
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_mcp_tool_image_content(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"data": "base64data==",
|
|
||||||
"mimeType": "image/png",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {
|
|
||||||
"type": "image",
|
|
||||||
"data": "base64data==",
|
|
||||||
"mimeType": "image/png",
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_with_credentials(self):
|
|
||||||
"""Verify the block uses OAuth2Credentials and passes auth token."""
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="test_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
captured_tokens: list[str | None] = []
|
|
||||||
|
|
||||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
|
||||||
captured_tokens.append(auth_token)
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
test_creds = OAuth2Credentials(
|
|
||||||
id="cred-123",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("resolved-token"),
|
|
||||||
refresh_token=SecretStr(""),
|
|
||||||
scopes=[],
|
|
||||||
title="Test MCP credential",
|
|
||||||
)
|
|
||||||
|
|
||||||
async for _ in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert captured_tokens == ["resolved-token"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_without_credentials(self):
|
|
||||||
"""Verify the block works without credentials (public server)."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="test_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
captured_tokens: list[str | None] = []
|
|
||||||
|
|
||||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
|
||||||
captured_tokens.append(auth_token)
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert captured_tokens == [None]
|
|
||||||
assert outputs == [("result", "ok")]
|
|
||||||
@@ -1,242 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for MCP OAuth handler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
|
|
||||||
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
|
|
||||||
resp = MagicMock()
|
|
||||||
resp.status = status
|
|
||||||
resp.ok = 200 <= status < 300
|
|
||||||
resp.json.return_value = json_data
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPOAuthHandler:
|
|
||||||
"""Tests for the MCPOAuthHandler."""
|
|
||||||
|
|
||||||
def _make_handler(self, **overrides) -> MCPOAuthHandler:
|
|
||||||
defaults = {
|
|
||||||
"client_id": "test-client-id",
|
|
||||||
"client_secret": "test-client-secret",
|
|
||||||
"redirect_uri": "https://app.example.com/callback",
|
|
||||||
"authorize_url": "https://auth.example.com/authorize",
|
|
||||||
"token_url": "https://auth.example.com/token",
|
|
||||||
}
|
|
||||||
defaults.update(overrides)
|
|
||||||
return MCPOAuthHandler(**defaults)
|
|
||||||
|
|
||||||
def test_get_login_url_basic(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
url = handler.get_login_url(
|
|
||||||
scopes=["read", "write"],
|
|
||||||
state="random-state-token",
|
|
||||||
code_challenge="S256-challenge-value",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "https://auth.example.com/authorize?" in url
|
|
||||||
assert "response_type=code" in url
|
|
||||||
assert "client_id=test-client-id" in url
|
|
||||||
assert "state=random-state-token" in url
|
|
||||||
assert "code_challenge=S256-challenge-value" in url
|
|
||||||
assert "code_challenge_method=S256" in url
|
|
||||||
assert "scope=read+write" in url
|
|
||||||
|
|
||||||
def test_get_login_url_with_resource(self):
|
|
||||||
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
|
|
||||||
url = handler.get_login_url(
|
|
||||||
scopes=[], state="state", code_challenge="challenge"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "resource=https" in url
|
|
||||||
|
|
||||||
def test_get_login_url_without_pkce(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
|
|
||||||
|
|
||||||
assert "code_challenge" not in url
|
|
||||||
assert "code_challenge_method" not in url
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_exchange_code_for_tokens(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
resp = _mock_response(
|
|
||||||
{
|
|
||||||
"access_token": "new-access-token",
|
|
||||||
"refresh_token": "new-refresh-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
"token_type": "Bearer",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
creds = await handler.exchange_code_for_tokens(
|
|
||||||
code="auth-code",
|
|
||||||
scopes=["read"],
|
|
||||||
code_verifier="pkce-verifier",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(creds, OAuth2Credentials)
|
|
||||||
assert creds.access_token.get_secret_value() == "new-access-token"
|
|
||||||
assert creds.refresh_token is not None
|
|
||||||
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
|
|
||||||
assert creds.scopes == ["read"]
|
|
||||||
assert creds.access_token_expires_at is not None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_refresh_tokens(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
existing_creds = OAuth2Credentials(
|
|
||||||
id="existing-id",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("old-token"),
|
|
||||||
refresh_token=SecretStr("old-refresh"),
|
|
||||||
scopes=["read"],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = _mock_response(
|
|
||||||
{
|
|
||||||
"access_token": "refreshed-token",
|
|
||||||
"refresh_token": "new-refresh",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
refreshed = await handler._refresh_tokens(existing_creds)
|
|
||||||
|
|
||||||
assert refreshed.id == "existing-id"
|
|
||||||
assert refreshed.access_token.get_secret_value() == "refreshed-token"
|
|
||||||
assert refreshed.refresh_token is not None
|
|
||||||
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_refresh_tokens_no_refresh_token(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=["read"],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No refresh token"):
|
|
||||||
await handler._refresh_tokens(creds)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_revoke_tokens_no_url(self):
|
|
||||||
handler = self._make_handler(revoke_url=None)
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=[],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await handler.revoke_tokens(creds)
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_revoke_tokens_with_url(self):
|
|
||||||
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=[],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = _mock_response({}, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await handler.revoke_tokens(creds)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientDiscovery:
|
|
||||||
"""Tests for MCPClient OAuth metadata discovery."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_discover_auth_found(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"authorization_servers": ["https://auth.example.com"],
|
|
||||||
"resource": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = _mock_response(metadata, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth()
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result["authorization_servers"] == ["https://auth.example.com"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_discover_auth_not_found(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
resp = _mock_response({}, status=404)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth()
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_discover_auth_server_metadata(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
server_metadata = {
|
|
||||||
"issuer": "https://auth.example.com",
|
|
||||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
|
||||||
"token_endpoint": "https://auth.example.com/token",
|
|
||||||
"registration_endpoint": "https://auth.example.com/register",
|
|
||||||
"code_challenge_methods_supported": ["S256"],
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = _mock_response(server_metadata, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth_server_metadata(
|
|
||||||
"https://auth.example.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
|
|
||||||
assert result["token_endpoint"] == "https://auth.example.com/token"
|
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
"""
|
|
||||||
Minimal MCP server for integration testing.
|
|
||||||
|
|
||||||
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
|
|
||||||
with a few sample tools. Runs on localhost with a random available port.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from aiohttp import web
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Sample tools this test server exposes
|
|
||||||
TEST_TOOLS = [
|
|
||||||
{
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get current weather for a city",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "City name",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "add_numbers",
|
|
||||||
"description": "Add two numbers together",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"a": {"type": "number", "description": "First number"},
|
|
||||||
"b": {"type": "number", "description": "Second number"},
|
|
||||||
},
|
|
||||||
"required": ["a", "b"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "echo",
|
|
||||||
"description": "Echo back the input message",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"message": {"type": "string", "description": "Message to echo"},
|
|
||||||
},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_initialize(params: dict) -> dict:
|
|
||||||
return {
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {"tools": {"listChanged": False}},
|
|
||||||
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_tools_list(params: dict) -> dict:
|
|
||||||
return {"tools": TEST_TOOLS}
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_tools_call(params: dict) -> dict:
|
|
||||||
tool_name = params.get("name", "")
|
|
||||||
arguments = params.get("arguments", {})
|
|
||||||
|
|
||||||
if tool_name == "get_weather":
|
|
||||||
city = arguments.get("city", "Unknown")
|
|
||||||
return {
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": json.dumps(
|
|
||||||
{"city": city, "temperature": 22, "condition": "sunny"}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
elif tool_name == "add_numbers":
|
|
||||||
a = arguments.get("a", 0)
|
|
||||||
b = arguments.get("b", 0)
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
|
|
||||||
}
|
|
||||||
|
|
||||||
elif tool_name == "echo":
|
|
||||||
message = arguments.get("message", "")
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": message}],
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
HANDLERS = {
|
|
||||||
"initialize": _handle_initialize,
|
|
||||||
"tools/list": _handle_tools_list,
|
|
||||||
"tools/call": _handle_tools_call,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_request(request: web.Request) -> web.Response:
|
|
||||||
"""Handle incoming MCP JSON-RPC 2.0 requests."""
|
|
||||||
# Check auth if configured
|
|
||||||
expected_token = request.app.get("auth_token")
|
|
||||||
if expected_token:
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
|
||||||
if auth_header != f"Bearer {expected_token}":
|
|
||||||
return web.json_response(
|
|
||||||
{
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"error": {"code": -32001, "message": "Unauthorized"},
|
|
||||||
"id": None,
|
|
||||||
},
|
|
||||||
status=401,
|
|
||||||
)
|
|
||||||
|
|
||||||
body = await request.json()
|
|
||||||
|
|
||||||
# Handle notifications (no id field) — just acknowledge
|
|
||||||
if "id" not in body:
|
|
||||||
return web.Response(status=202)
|
|
||||||
|
|
||||||
method = body.get("method", "")
|
|
||||||
params = body.get("params", {})
|
|
||||||
request_id = body.get("id")
|
|
||||||
|
|
||||||
handler = HANDLERS.get(method)
|
|
||||||
if not handler:
|
|
||||||
return web.json_response(
|
|
||||||
{
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"error": {
|
|
||||||
"code": -32601,
|
|
||||||
"message": f"Method not found: {method}",
|
|
||||||
},
|
|
||||||
"id": request_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
result = handler(params)
|
|
||||||
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
|
|
||||||
"""Create an aiohttp app that acts as an MCP server."""
|
|
||||||
app = web.Application()
|
|
||||||
app.router.add_post("/mcp", handle_mcp_request)
|
|
||||||
if auth_token:
|
|
||||||
app["auth_token"] = auth_token
|
|
||||||
return app
|
|
||||||
246
autogpt_platform/backend/backend/blocks/media.py
Normal file
246
autogpt_platform/backend/backend/blocks/media.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||||
|
from moviepy.video.fx.Loop import Loop
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class MediaDurationBlock(Block):
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
media_in: MediaFileType = SchemaField(
|
||||||
|
description="Media input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
is_video: bool = SchemaField(
|
||||||
|
description="Whether the media is a video (True) or audio (False).",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
duration: float = SchemaField(
|
||||||
|
description="Duration of the media file (in seconds)."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
||||||
|
description="Block to get the duration of a media file.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=MediaDurationBlock.Input,
|
||||||
|
output_schema=MediaDurationBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# 1) Store the input media locally
|
||||||
|
local_media_path = await store_media_file(
|
||||||
|
file=input_data.media_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
media_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_media_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Load the clip
|
||||||
|
if input_data.is_video:
|
||||||
|
clip = VideoFileClip(media_abspath)
|
||||||
|
else:
|
||||||
|
clip = AudioFileClip(media_abspath)
|
||||||
|
|
||||||
|
yield "duration", clip.duration
|
||||||
|
|
||||||
|
|
||||||
|
class LoopVideoBlock(Block):
|
||||||
|
"""
|
||||||
|
Block for looping (repeating) a video clip until a given duration or number of loops.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="The input video (can be a URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
|
||||||
|
duration: Optional[float] = SchemaField(
|
||||||
|
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
|
||||||
|
default=None,
|
||||||
|
ge=0.0,
|
||||||
|
)
|
||||||
|
n_loops: Optional[int] = SchemaField(
|
||||||
|
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
|
||||||
|
default=None,
|
||||||
|
ge=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: str = SchemaField(
|
||||||
|
description="Looped video returned either as a relative path or a data URI."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
||||||
|
description="Block to loop a video to a given duration or number of repeats.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=LoopVideoBlock.Input,
|
||||||
|
output_schema=LoopVideoBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
|
# 1) Store the input video locally
|
||||||
|
local_video_path = await store_media_file(
|
||||||
|
file=input_data.video_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||||
|
|
||||||
|
# 2) Load the clip
|
||||||
|
clip = VideoFileClip(input_abspath)
|
||||||
|
|
||||||
|
# 3) Apply the loop effect
|
||||||
|
looped_clip = clip
|
||||||
|
if input_data.duration:
|
||||||
|
# Loop until we reach the specified duration
|
||||||
|
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
|
||||||
|
elif input_data.n_loops:
|
||||||
|
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
|
||||||
|
else:
|
||||||
|
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||||
|
|
||||||
|
assert isinstance(looped_clip, VideoFileClip)
|
||||||
|
|
||||||
|
# 4) Save the looped output
|
||||||
|
output_filename = MediaFileType(
|
||||||
|
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
||||||
|
)
|
||||||
|
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||||
|
|
||||||
|
looped_clip = looped_clip.with_audio(clip.audio)
|
||||||
|
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||||
|
|
||||||
|
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
|
video_out = await store_media_file(
|
||||||
|
file=output_filename,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
|
|
||||||
|
|
||||||
|
class AddAudioToVideoBlock(Block):
|
||||||
|
"""
|
||||||
|
Block that adds (attaches) an audio track to an existing video.
|
||||||
|
Optionally scale the volume of the new track.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="Video input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
audio_in: MediaFileType = SchemaField(
|
||||||
|
description="Audio input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
volume: float = SchemaField(
|
||||||
|
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||||
|
default=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Final video (with attached audio), as a path or data URI."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||||
|
description="Block to attach an audio file to a video file using moviepy.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=AddAudioToVideoBlock.Input,
|
||||||
|
output_schema=AddAudioToVideoBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
|
# 1) Store the inputs locally
|
||||||
|
local_video_path = await store_media_file(
|
||||||
|
file=input_data.video_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
local_audio_path = await store_media_file(
|
||||||
|
file=input_data.audio_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||||
|
video_abspath = os.path.join(abs_temp_dir, local_video_path)
|
||||||
|
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
|
||||||
|
|
||||||
|
# 2) Load video + audio with moviepy
|
||||||
|
video_clip = VideoFileClip(video_abspath)
|
||||||
|
audio_clip = AudioFileClip(audio_abspath)
|
||||||
|
# Optionally scale volume
|
||||||
|
if input_data.volume != 1.0:
|
||||||
|
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||||
|
|
||||||
|
# 3) Attach the new audio track
|
||||||
|
final_clip = video_clip.with_audio(audio_clip)
|
||||||
|
|
||||||
|
# 4) Write to output file
|
||||||
|
output_filename = MediaFileType(
|
||||||
|
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
||||||
|
)
|
||||||
|
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||||
|
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||||
|
|
||||||
|
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
|
video_out = await store_media_file(
|
||||||
|
file=output_filename,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from backend.blocks.encoder_block import TextEncoderBlock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_encoder_basic():
|
|
||||||
"""Test basic encoding of newlines and special characters."""
|
|
||||||
block = TextEncoderBlock()
|
|
||||||
result = []
|
|
||||||
async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
|
|
||||||
result.append(output)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0][0] == "encoded_text"
|
|
||||||
assert result[0][1] == "Hello\\nWorld"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_encoder_multiple_escapes():
|
|
||||||
"""Test encoding of multiple escape sequences."""
|
|
||||||
block = TextEncoderBlock()
|
|
||||||
result = []
|
|
||||||
async for output in block.run(
|
|
||||||
TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
|
|
||||||
):
|
|
||||||
result.append(output)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0][0] == "encoded_text"
|
|
||||||
assert "\\n" in result[0][1]
|
|
||||||
assert "\\t" in result[0][1]
|
|
||||||
assert "\\r" in result[0][1]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_encoder_unicode():
|
|
||||||
"""Test that unicode characters are handled correctly."""
|
|
||||||
block = TextEncoderBlock()
|
|
||||||
result = []
|
|
||||||
async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
|
|
||||||
result.append(output)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0][0] == "encoded_text"
|
|
||||||
# Unicode characters should be escaped as \uXXXX sequences
|
|
||||||
assert "\\n" in result[0][1]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_encoder_empty_string():
|
|
||||||
"""Test encoding of an empty string."""
|
|
||||||
block = TextEncoderBlock()
|
|
||||||
result = []
|
|
||||||
async for output in block.run(TextEncoderBlock.Input(text="")):
|
|
||||||
result.append(output)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0][0] == "encoded_text"
|
|
||||||
assert result[0][1] == ""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_encoder_error_handling():
|
|
||||||
"""Test that encoding errors are handled gracefully."""
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
block = TextEncoderBlock()
|
|
||||||
result = []
|
|
||||||
|
|
||||||
with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
|
|
||||||
async for output in block.run(TextEncoderBlock.Input(text="test")):
|
|
||||||
result.append(output)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0][0] == "error"
|
|
||||||
assert "Mocked encoding error" in result[0][1]
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Video editing blocks for AutoGPT Platform.
|
|
||||||
|
|
||||||
This module provides blocks for:
|
|
||||||
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
|
|
||||||
- Clipping/trimming video segments
|
|
||||||
- Concatenating multiple videos
|
|
||||||
- Adding text overlays
|
|
||||||
- Adding AI-generated narration
|
|
||||||
- Getting media duration
|
|
||||||
- Looping videos
|
|
||||||
- Adding audio to videos
|
|
||||||
|
|
||||||
Dependencies:
|
|
||||||
- yt-dlp: For video downloading
|
|
||||||
- moviepy: For video editing operations
|
|
||||||
- elevenlabs: For AI narration (optional)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from backend.blocks.video.add_audio import AddAudioToVideoBlock
|
|
||||||
from backend.blocks.video.clip import VideoClipBlock
|
|
||||||
from backend.blocks.video.concat import VideoConcatBlock
|
|
||||||
from backend.blocks.video.download import VideoDownloadBlock
|
|
||||||
from backend.blocks.video.duration import MediaDurationBlock
|
|
||||||
from backend.blocks.video.loop import LoopVideoBlock
|
|
||||||
from backend.blocks.video.narration import VideoNarrationBlock
|
|
||||||
from backend.blocks.video.text_overlay import VideoTextOverlayBlock
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AddAudioToVideoBlock",
|
|
||||||
"LoopVideoBlock",
|
|
||||||
"MediaDurationBlock",
|
|
||||||
"VideoClipBlock",
|
|
||||||
"VideoConcatBlock",
|
|
||||||
"VideoDownloadBlock",
|
|
||||||
"VideoNarrationBlock",
|
|
||||||
"VideoTextOverlayBlock",
|
|
||||||
]
|
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
"""Shared utilities for video blocks."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Known operation tags added by video blocks
|
|
||||||
_VIDEO_OPS = (
|
|
||||||
r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID
|
|
||||||
_BLOCK_PREFIX_RE = re.compile(
|
|
||||||
r"^[a-zA-Z0-9_-]*"
|
|
||||||
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
|
||||||
r"[a-zA-Z0-9_-]*"
|
|
||||||
r"_" + _VIDEO_OPS + r"_"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output)
|
|
||||||
_UUID_PREFIX_RE = re.compile(
|
|
||||||
r"^[a-zA-Z0-9_-]*"
|
|
||||||
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
|
||||||
r"[a-zA-Z0-9_-]*_"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_source_name(input_path: str, max_length: int = 50) -> str:
|
|
||||||
"""Extract the original source filename by stripping block-generated prefixes.
|
|
||||||
|
|
||||||
Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate
|
|
||||||
when chaining video blocks, recovering the original human-readable name.
|
|
||||||
|
|
||||||
Safe for plain filenames (no UUID -> no stripping).
|
|
||||||
Falls back to "video" if everything is stripped.
|
|
||||||
"""
|
|
||||||
stem = Path(input_path).stem
|
|
||||||
|
|
||||||
# Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively
|
|
||||||
while _BLOCK_PREFIX_RE.match(stem):
|
|
||||||
stem = _BLOCK_PREFIX_RE.sub("", stem, count=1)
|
|
||||||
|
|
||||||
# Pass 2: strip a lone {node_exec_id}_ prefix (e.g. from download block)
|
|
||||||
if _UUID_PREFIX_RE.match(stem):
|
|
||||||
stem = _UUID_PREFIX_RE.sub("", stem, count=1)
|
|
||||||
|
|
||||||
if not stem:
|
|
||||||
return "video"
|
|
||||||
|
|
||||||
return stem[:max_length]
|
|
||||||
|
|
||||||
|
|
||||||
def get_video_codecs(output_path: str) -> tuple[str, str]:
|
|
||||||
"""Get appropriate video and audio codecs based on output file extension.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_path: Path to the output file (used to determine extension)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (video_codec, audio_codec)
|
|
||||||
|
|
||||||
Codec mappings:
|
|
||||||
- .mp4: H.264 + AAC (universal compatibility)
|
|
||||||
- .webm: VP8 + Vorbis (web streaming)
|
|
||||||
- .mkv: H.264 + AAC (container supports many codecs)
|
|
||||||
- .mov: H.264 + AAC (Apple QuickTime, widely compatible)
|
|
||||||
- .m4v: H.264 + AAC (Apple iTunes/devices)
|
|
||||||
- .avi: MPEG-4 + MP3 (legacy Windows)
|
|
||||||
"""
|
|
||||||
ext = os.path.splitext(output_path)[1].lower()
|
|
||||||
|
|
||||||
codec_map: dict[str, tuple[str, str]] = {
|
|
||||||
".mp4": ("libx264", "aac"),
|
|
||||||
".webm": ("libvpx", "libvorbis"),
|
|
||||||
".mkv": ("libx264", "aac"),
|
|
||||||
".mov": ("libx264", "aac"),
|
|
||||||
".m4v": ("libx264", "aac"),
|
|
||||||
".avi": ("mpeg4", "libmp3lame"),
|
|
||||||
}
|
|
||||||
|
|
||||||
return codec_map.get(ext, ("libx264", "aac"))
|
|
||||||
|
|
||||||
|
|
||||||
def strip_chapters_inplace(video_path: str) -> None:
|
|
||||||
"""Strip chapter metadata from a media file in-place using ffmpeg.
|
|
||||||
|
|
||||||
MoviePy 2.x crashes with IndexError when parsing files with embedded
|
|
||||||
chapter metadata (https://github.com/Zulko/moviepy/issues/2419).
|
|
||||||
This strips chapters without re-encoding.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
video_path: Absolute path to the media file to strip chapters from.
|
|
||||||
"""
|
|
||||||
base, ext = os.path.splitext(video_path)
|
|
||||||
tmp_path = base + ".tmp" + ext
|
|
||||||
try:
|
|
||||||
result = subprocess.run(
|
|
||||||
[
|
|
||||||
"ffmpeg",
|
|
||||||
"-y",
|
|
||||||
"-i",
|
|
||||||
video_path,
|
|
||||||
"-map_chapters",
|
|
||||||
"-1",
|
|
||||||
"-codec",
|
|
||||||
"copy",
|
|
||||||
tmp_path,
|
|
||||||
],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=300,
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
logger.warning(
|
|
||||||
"ffmpeg chapter strip failed (rc=%d): %s",
|
|
||||||
result.returncode,
|
|
||||||
result.stderr,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
os.replace(tmp_path, video_path)
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.warning("ffmpeg not found; skipping chapter strip")
|
|
||||||
finally:
|
|
||||||
if os.path.exists(tmp_path):
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
"""AddAudioToVideoBlock - Attach an audio track to a video file."""
|
|
||||||
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class AddAudioToVideoBlock(Block):
|
|
||||||
"""Add (attach) an audio track to an existing video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Video input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
audio_in: MediaFileType = SchemaField(
|
|
||||||
description="Audio input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
volume: float = SchemaField(
|
|
||||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
|
||||||
default=1.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Final video (with attached audio), as a path or data URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
|
||||||
description="Block to attach an audio file to a video file using moviepy.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=AddAudioToVideoBlock.Input,
|
|
||||||
output_schema=AddAudioToVideoBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
assert execution_context.node_exec_id is not None
|
|
||||||
graph_exec_id = execution_context.graph_exec_id
|
|
||||||
node_exec_id = execution_context.node_exec_id
|
|
||||||
|
|
||||||
# 1) Store the inputs locally
|
|
||||||
local_video_path = await store_media_file(
|
|
||||||
file=input_data.video_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
local_audio_path = await store_media_file(
|
|
||||||
file=input_data.audio_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
video_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
|
||||||
audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path)
|
|
||||||
|
|
||||||
# 2) Load video + audio with moviepy
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
strip_chapters_inplace(audio_abspath)
|
|
||||||
video_clip = None
|
|
||||||
audio_clip = None
|
|
||||||
final_clip = None
|
|
||||||
try:
|
|
||||||
video_clip = VideoFileClip(video_abspath)
|
|
||||||
audio_clip = AudioFileClip(audio_abspath)
|
|
||||||
# Optionally scale volume
|
|
||||||
if input_data.volume != 1.0:
|
|
||||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
|
||||||
|
|
||||||
# 3) Attach the new audio track
|
|
||||||
final_clip = video_clip.with_audio(audio_clip)
|
|
||||||
|
|
||||||
# 4) Write to output file
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
|
||||||
final_clip.write_videofile(
|
|
||||||
output_abspath, codec="libx264", audio_codec="aac"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if final_clip:
|
|
||||||
final_clip.close()
|
|
||||||
if audio_clip:
|
|
||||||
audio_clip.close()
|
|
||||||
if video_clip:
|
|
||||||
video_clip.close()
|
|
||||||
|
|
||||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
|
||||||
video_out = await store_media_file(
|
|
||||||
file=output_filename,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
@@ -1,167 +0,0 @@
|
|||||||
"""VideoClipBlock - Extract a segment from a video file."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoClipBlock(Block):
|
|
||||||
"""Extract a time segment from a video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Input video (URL, data URI, or local path)"
|
|
||||||
)
|
|
||||||
start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
|
|
||||||
end_time: float = SchemaField(description="End time in seconds", ge=0.0)
|
|
||||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
|
||||||
description="Output format", default="mp4", advanced=True
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Clipped video file (path or data URI)"
|
|
||||||
)
|
|
||||||
duration: float = SchemaField(description="Clip duration in seconds")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8f539119-e580-4d86-ad41-86fbcb22abb1",
|
|
||||||
description="Extract a time segment from a video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
test_input={
|
|
||||||
"video_in": "/tmp/test.mp4",
|
|
||||||
"start_time": 0.0,
|
|
||||||
"end_time": 10.0,
|
|
||||||
},
|
|
||||||
test_output=[("video_out", str), ("duration", float)],
|
|
||||||
test_mock={
|
|
||||||
"_clip_video": lambda *args: 10.0,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _clip_video(
|
|
||||||
self,
|
|
||||||
video_abspath: str,
|
|
||||||
output_abspath: str,
|
|
||||||
start_time: float,
|
|
||||||
end_time: float,
|
|
||||||
) -> float:
|
|
||||||
"""Extract a clip from a video. Extracted for testability."""
|
|
||||||
clip = None
|
|
||||||
subclip = None
|
|
||||||
try:
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
clip = VideoFileClip(video_abspath)
|
|
||||||
subclip = clip.subclipped(start_time, end_time)
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
subclip.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
return subclip.duration
|
|
||||||
finally:
|
|
||||||
if subclip:
|
|
||||||
subclip.close()
|
|
||||||
if clip:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# Validate time range
|
|
||||||
if input_data.end_time <= input_data.start_time:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store the input video locally
|
|
||||||
local_video_path = await self._store_input_video(
|
|
||||||
execution_context, input_data.video_in
|
|
||||||
)
|
|
||||||
video_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_video_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build output path
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(
|
|
||||||
f"{node_exec_id}_clip_{source}.{input_data.output_format}"
|
|
||||||
)
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = self._clip_video(
|
|
||||||
video_abspath,
|
|
||||||
output_abspath,
|
|
||||||
input_data.start_time,
|
|
||||||
input_data.end_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
yield "duration", duration
|
|
||||||
|
|
||||||
except BlockExecutionError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to clip video: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,227 +0,0 @@
|
|||||||
"""VideoConcatBlock - Concatenate multiple video clips into one."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from moviepy import concatenate_videoclips
|
|
||||||
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoConcatBlock(Block):
|
|
||||||
"""Merge multiple video clips into one continuous video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
videos: list[MediaFileType] = SchemaField(
|
|
||||||
description="List of video files to concatenate (in order)"
|
|
||||||
)
|
|
||||||
transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
|
|
||||||
description="Transition between clips", default="none"
|
|
||||||
)
|
|
||||||
transition_duration: int = SchemaField(
|
|
||||||
description="Transition duration in seconds",
|
|
||||||
default=1,
|
|
||||||
ge=0,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
|
||||||
description="Output format", default="mp4", advanced=True
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Concatenated video file (path or data URI)"
|
|
||||||
)
|
|
||||||
total_duration: float = SchemaField(description="Total duration in seconds")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
|
|
||||||
description="Merge multiple video clips into one continuous video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
test_input={
|
|
||||||
"videos": ["/tmp/a.mp4", "/tmp/b.mp4"],
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
("video_out", str),
|
|
||||||
("total_duration", float),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"_concat_videos": lambda *args: 20.0,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _concat_videos(
|
|
||||||
self,
|
|
||||||
video_abspaths: list[str],
|
|
||||||
output_abspath: str,
|
|
||||||
transition: str,
|
|
||||||
transition_duration: int,
|
|
||||||
) -> float:
|
|
||||||
"""Concatenate videos. Extracted for testability.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total duration of the concatenated video.
|
|
||||||
"""
|
|
||||||
clips = []
|
|
||||||
faded_clips = []
|
|
||||||
final = None
|
|
||||||
try:
|
|
||||||
# Load clips
|
|
||||||
for v in video_abspaths:
|
|
||||||
strip_chapters_inplace(v)
|
|
||||||
clips.append(VideoFileClip(v))
|
|
||||||
|
|
||||||
# Validate transition_duration against shortest clip
|
|
||||||
if transition in {"crossfade", "fade_black"} and transition_duration > 0:
|
|
||||||
min_duration = min(c.duration for c in clips)
|
|
||||||
if transition_duration >= min_duration:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=(
|
|
||||||
f"transition_duration ({transition_duration}s) must be "
|
|
||||||
f"shorter than the shortest clip ({min_duration:.2f}s)"
|
|
||||||
),
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
if transition == "crossfade":
|
|
||||||
for i, clip in enumerate(clips):
|
|
||||||
effects = []
|
|
||||||
if i > 0:
|
|
||||||
effects.append(CrossFadeIn(transition_duration))
|
|
||||||
if i < len(clips) - 1:
|
|
||||||
effects.append(CrossFadeOut(transition_duration))
|
|
||||||
if effects:
|
|
||||||
clip = clip.with_effects(effects)
|
|
||||||
faded_clips.append(clip)
|
|
||||||
final = concatenate_videoclips(
|
|
||||||
faded_clips,
|
|
||||||
method="compose",
|
|
||||||
padding=-transition_duration,
|
|
||||||
)
|
|
||||||
elif transition == "fade_black":
|
|
||||||
for clip in clips:
|
|
||||||
faded = clip.with_effects(
|
|
||||||
[FadeIn(transition_duration), FadeOut(transition_duration)]
|
|
||||||
)
|
|
||||||
faded_clips.append(faded)
|
|
||||||
final = concatenate_videoclips(faded_clips)
|
|
||||||
else:
|
|
||||||
final = concatenate_videoclips(clips)
|
|
||||||
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
final.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
|
|
||||||
return final.duration
|
|
||||||
finally:
|
|
||||||
if final:
|
|
||||||
final.close()
|
|
||||||
for clip in faded_clips:
|
|
||||||
clip.close()
|
|
||||||
for clip in clips:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# Validate minimum clips
|
|
||||||
if len(input_data.videos) < 2:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message="At least 2 videos are required for concatenation",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store all input videos locally
|
|
||||||
video_abspaths = []
|
|
||||||
for video in input_data.videos:
|
|
||||||
local_path = await self._store_input_video(execution_context, video)
|
|
||||||
video_abspaths.append(
|
|
||||||
get_exec_file_path(execution_context.graph_exec_id, local_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build output path
|
|
||||||
source = (
|
|
||||||
extract_source_name(video_abspaths[0]) if video_abspaths else "video"
|
|
||||||
)
|
|
||||||
output_filename = MediaFileType(
|
|
||||||
f"{node_exec_id}_concat_{source}.{input_data.output_format}"
|
|
||||||
)
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
total_duration = self._concat_videos(
|
|
||||||
video_abspaths,
|
|
||||||
output_abspath,
|
|
||||||
input_data.transition,
|
|
||||||
input_data.transition_duration,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
yield "total_duration", total_duration
|
|
||||||
|
|
||||||
except BlockExecutionError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to concatenate videos: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,172 +0,0 @@
|
|||||||
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import typing
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import yt_dlp
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from yt_dlp import _Params
|
|
||||||
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoDownloadBlock(Block):
|
|
||||||
"""Download video from URL using yt-dlp."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
url: str = SchemaField(
|
|
||||||
description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
|
|
||||||
placeholder="https://www.youtube.com/watch?v=...",
|
|
||||||
)
|
|
||||||
quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
|
|
||||||
description="Video quality preference", default="720p"
|
|
||||||
)
|
|
||||||
output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
|
|
||||||
description="Output video format", default="mp4", advanced=True
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_file: MediaFileType = SchemaField(
|
|
||||||
description="Downloaded video (path or data URI)"
|
|
||||||
)
|
|
||||||
duration: float = SchemaField(description="Video duration in seconds")
|
|
||||||
title: str = SchemaField(description="Video title from source")
|
|
||||||
source_url: str = SchemaField(description="Original source URL")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
|
|
||||||
description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
disabled=True, # Disable until we can sandbox yt-dlp and handle security implications
|
|
||||||
test_input={
|
|
||||||
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
|
||||||
"quality": "480p",
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
("video_file", str),
|
|
||||||
("duration", float),
|
|
||||||
("title", str),
|
|
||||||
("source_url", str),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"_download_video": lambda *args: (
|
|
||||||
"video.mp4",
|
|
||||||
212.0,
|
|
||||||
"Test Video",
|
|
||||||
),
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "video.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_format_string(self, quality: str) -> str:
|
|
||||||
formats = {
|
|
||||||
"best": "bestvideo+bestaudio/best",
|
|
||||||
"1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
|
|
||||||
"720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
|
|
||||||
"480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
|
|
||||||
"audio_only": "bestaudio/best",
|
|
||||||
}
|
|
||||||
return formats.get(quality, formats["720p"])
|
|
||||||
|
|
||||||
def _download_video(
|
|
||||||
self,
|
|
||||||
url: str,
|
|
||||||
quality: str,
|
|
||||||
output_format: str,
|
|
||||||
output_dir: str,
|
|
||||||
node_exec_id: str,
|
|
||||||
) -> tuple[str, float, str]:
|
|
||||||
"""Download video. Extracted for testability."""
|
|
||||||
output_template = os.path.join(
|
|
||||||
output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
|
|
||||||
)
|
|
||||||
|
|
||||||
ydl_opts: "_Params" = {
|
|
||||||
"format": f"{self._get_format_string(quality)}/best",
|
|
||||||
"outtmpl": output_template,
|
|
||||||
"merge_output_format": output_format,
|
|
||||||
"quiet": True,
|
|
||||||
"no_warnings": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
|
||||||
info = ydl.extract_info(url, download=True)
|
|
||||||
video_path = ydl.prepare_filename(info)
|
|
||||||
|
|
||||||
# Handle format conversion in filename
|
|
||||||
if not video_path.endswith(f".{output_format}"):
|
|
||||||
video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"
|
|
||||||
|
|
||||||
# Return just the filename, not the full path
|
|
||||||
filename = os.path.basename(video_path)
|
|
||||||
|
|
||||||
return (
|
|
||||||
filename,
|
|
||||||
info.get("duration") or 0.0,
|
|
||||||
info.get("title") or "Unknown",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Get the exec file directory
|
|
||||||
output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
filename, duration, title = self._download_video(
|
|
||||||
input_data.url,
|
|
||||||
input_data.quality,
|
|
||||||
input_data.output_format,
|
|
||||||
output_dir,
|
|
||||||
node_exec_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, MediaFileType(filename)
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_file", video_out
|
|
||||||
yield "duration", duration
|
|
||||||
yield "title", title
|
|
||||||
yield "source_url", input_data.url
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to download video: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
"""MediaDurationBlock - Get the duration of a media file."""
|
|
||||||
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import strip_chapters_inplace
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class MediaDurationBlock(Block):
|
|
||||||
"""Get the duration of a media file (video or audio)."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
media_in: MediaFileType = SchemaField(
|
|
||||||
description="Media input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
is_video: bool = SchemaField(
|
|
||||||
description="Whether the media is a video (True) or audio (False).",
|
|
||||||
default=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
duration: float = SchemaField(
|
|
||||||
description="Duration of the media file (in seconds)."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
|
||||||
description="Block to get the duration of a media file.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=MediaDurationBlock.Input,
|
|
||||||
output_schema=MediaDurationBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# 1) Store the input media locally
|
|
||||||
local_media_path = await store_media_file(
|
|
||||||
file=input_data.media_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
media_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_media_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2) Strip chapters to avoid MoviePy crash, then load the clip
|
|
||||||
strip_chapters_inplace(media_abspath)
|
|
||||||
clip = None
|
|
||||||
try:
|
|
||||||
if input_data.is_video:
|
|
||||||
clip = VideoFileClip(media_abspath)
|
|
||||||
else:
|
|
||||||
clip = AudioFileClip(media_abspath)
|
|
||||||
|
|
||||||
duration = clip.duration
|
|
||||||
finally:
|
|
||||||
if clip:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
yield "duration", duration
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from moviepy.video.fx.Loop import Loop
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class LoopVideoBlock(Block):
|
|
||||||
"""Loop (repeat) a video clip until a given duration or number of loops."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="The input video (can be a URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
duration: Optional[float] = SchemaField(
|
|
||||||
description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.",
|
|
||||||
default=None,
|
|
||||||
ge=0.0,
|
|
||||||
le=3600.0, # Max 1 hour to prevent disk exhaustion
|
|
||||||
)
|
|
||||||
n_loops: Optional[int] = SchemaField(
|
|
||||||
description="Number of times to repeat the video. Either n_loops or duration must be provided.",
|
|
||||||
default=None,
|
|
||||||
ge=1,
|
|
||||||
le=10, # Max 10 loops to prevent disk exhaustion
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Looped video returned either as a relative path or a data URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
|
||||||
description="Block to loop a video to a given duration or number of repeats.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=LoopVideoBlock.Input,
|
|
||||||
output_schema=LoopVideoBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
assert execution_context.node_exec_id is not None
|
|
||||||
graph_exec_id = execution_context.graph_exec_id
|
|
||||||
node_exec_id = execution_context.node_exec_id
|
|
||||||
|
|
||||||
# 1) Store the input video locally
|
|
||||||
local_video_path = await store_media_file(
|
|
||||||
file=input_data.video_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
|
||||||
|
|
||||||
# 2) Load the clip
|
|
||||||
strip_chapters_inplace(input_abspath)
|
|
||||||
clip = None
|
|
||||||
looped_clip = None
|
|
||||||
try:
|
|
||||||
clip = VideoFileClip(input_abspath)
|
|
||||||
|
|
||||||
# 3) Apply the loop effect
|
|
||||||
if input_data.duration:
|
|
||||||
# Loop until we reach the specified duration
|
|
||||||
looped_clip = clip.with_effects([Loop(duration=input_data.duration)])
|
|
||||||
elif input_data.n_loops:
|
|
||||||
looped_clip = clip.with_effects([Loop(n=input_data.n_loops)])
|
|
||||||
else:
|
|
||||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
|
||||||
|
|
||||||
assert isinstance(looped_clip, VideoFileClip)
|
|
||||||
|
|
||||||
# 4) Save the looped output
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
|
||||||
|
|
||||||
looped_clip = looped_clip.with_audio(clip.audio)
|
|
||||||
looped_clip.write_videofile(
|
|
||||||
output_abspath, codec="libx264", audio_codec="aac"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if looped_clip:
|
|
||||||
looped_clip.close()
|
|
||||||
if clip:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
|
||||||
video_out = await store_media_file(
|
|
||||||
file=output_filename,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
@@ -1,267 +0,0 @@
|
|||||||
"""VideoNarrationBlock - Generate AI voice narration and add to video."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from elevenlabs import ElevenLabs
|
|
||||||
from moviepy import CompositeAudioClip
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.elevenlabs._auth import (
|
|
||||||
TEST_CREDENTIALS,
|
|
||||||
TEST_CREDENTIALS_INPUT,
|
|
||||||
ElevenLabsCredentials,
|
|
||||||
ElevenLabsCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoNarrationBlock(Block):
|
|
||||||
"""Generate AI narration and add to video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: ElevenLabsCredentialsInput = CredentialsField(
|
|
||||||
description="ElevenLabs API key for voice synthesis"
|
|
||||||
)
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Input video (URL, data URI, or local path)"
|
|
||||||
)
|
|
||||||
script: str = SchemaField(description="Narration script text")
|
|
||||||
voice_id: str = SchemaField(
|
|
||||||
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
|
|
||||||
)
|
|
||||||
model_id: Literal[
|
|
||||||
"eleven_multilingual_v2",
|
|
||||||
"eleven_flash_v2_5",
|
|
||||||
"eleven_turbo_v2_5",
|
|
||||||
"eleven_turbo_v2",
|
|
||||||
] = SchemaField(
|
|
||||||
description="ElevenLabs TTS model",
|
|
||||||
default="eleven_multilingual_v2",
|
|
||||||
)
|
|
||||||
mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
|
|
||||||
description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
|
|
||||||
default="ducking",
|
|
||||||
)
|
|
||||||
narration_volume: float = SchemaField(
|
|
||||||
description="Narration volume (0.0 to 2.0)",
|
|
||||||
default=1.0,
|
|
||||||
ge=0.0,
|
|
||||||
le=2.0,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
original_volume: float = SchemaField(
|
|
||||||
description="Original audio volume when mixing (0.0 to 1.0)",
|
|
||||||
default=0.3,
|
|
||||||
ge=0.0,
|
|
||||||
le=1.0,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Video with narration (path or data URI)"
|
|
||||||
)
|
|
||||||
audio_file: MediaFileType = SchemaField(
|
|
||||||
description="Generated audio file (path or data URI)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="3d036b53-859c-4b17-9826-ca340f736e0e",
|
|
||||||
description="Generate AI narration and add to video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
test_input={
|
|
||||||
"video_in": "/tmp/test.mp4",
|
|
||||||
"script": "Hello world",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[("video_out", str), ("audio_file", str)],
|
|
||||||
test_mock={
|
|
||||||
"_generate_narration_audio": lambda *args: b"mock audio content",
|
|
||||||
"_add_narration_to_video": lambda *args: None,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _generate_narration_audio(
|
|
||||||
self, api_key: str, script: str, voice_id: str, model_id: str
|
|
||||||
) -> bytes:
|
|
||||||
"""Generate narration audio via ElevenLabs API."""
|
|
||||||
client = ElevenLabs(api_key=api_key)
|
|
||||||
audio_generator = client.text_to_speech.convert(
|
|
||||||
voice_id=voice_id,
|
|
||||||
text=script,
|
|
||||||
model_id=model_id,
|
|
||||||
)
|
|
||||||
# The SDK returns a generator, collect all chunks
|
|
||||||
return b"".join(audio_generator)
|
|
||||||
|
|
||||||
def _add_narration_to_video(
|
|
||||||
self,
|
|
||||||
video_abspath: str,
|
|
||||||
audio_abspath: str,
|
|
||||||
output_abspath: str,
|
|
||||||
mix_mode: str,
|
|
||||||
narration_volume: float,
|
|
||||||
original_volume: float,
|
|
||||||
) -> None:
|
|
||||||
"""Add narration audio to video. Extracted for testability."""
|
|
||||||
video = None
|
|
||||||
final = None
|
|
||||||
narration_original = None
|
|
||||||
narration_scaled = None
|
|
||||||
original = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
video = VideoFileClip(video_abspath)
|
|
||||||
narration_original = AudioFileClip(audio_abspath)
|
|
||||||
narration_scaled = narration_original.with_volume_scaled(narration_volume)
|
|
||||||
narration = narration_scaled
|
|
||||||
|
|
||||||
if mix_mode == "replace":
|
|
||||||
final_audio = narration
|
|
||||||
elif mix_mode == "mix":
|
|
||||||
if video.audio:
|
|
||||||
original = video.audio.with_volume_scaled(original_volume)
|
|
||||||
final_audio = CompositeAudioClip([original, narration])
|
|
||||||
else:
|
|
||||||
final_audio = narration
|
|
||||||
else: # ducking - apply stronger attenuation
|
|
||||||
if video.audio:
|
|
||||||
# Ducking uses a much lower volume for original audio
|
|
||||||
ducking_volume = original_volume * 0.3
|
|
||||||
original = video.audio.with_volume_scaled(ducking_volume)
|
|
||||||
final_audio = CompositeAudioClip([original, narration])
|
|
||||||
else:
|
|
||||||
final_audio = narration
|
|
||||||
|
|
||||||
final = video.with_audio(final_audio)
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
final.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if original:
|
|
||||||
original.close()
|
|
||||||
if narration_scaled:
|
|
||||||
narration_scaled.close()
|
|
||||||
if narration_original:
|
|
||||||
narration_original.close()
|
|
||||||
if final:
|
|
||||||
final.close()
|
|
||||||
if video:
|
|
||||||
video.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: ElevenLabsCredentials,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store the input video locally
|
|
||||||
local_video_path = await self._store_input_video(
|
|
||||||
execution_context, input_data.video_in
|
|
||||||
)
|
|
||||||
video_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_video_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate narration audio via ElevenLabs
|
|
||||||
audio_content = self._generate_narration_audio(
|
|
||||||
credentials.api_key.get_secret_value(),
|
|
||||||
input_data.script,
|
|
||||||
input_data.voice_id,
|
|
||||||
input_data.model_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save audio to exec file path
|
|
||||||
audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
|
|
||||||
audio_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, audio_filename
|
|
||||||
)
|
|
||||||
os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
|
|
||||||
with open(audio_abspath, "wb") as f:
|
|
||||||
f.write(audio_content)
|
|
||||||
|
|
||||||
# Add narration to video
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
self._add_narration_to_video(
|
|
||||||
video_abspath,
|
|
||||||
audio_abspath,
|
|
||||||
output_abspath,
|
|
||||||
input_data.mix_mode,
|
|
||||||
input_data.narration_volume,
|
|
||||||
input_data.original_volume,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
audio_out = await self._store_output_video(
|
|
||||||
execution_context, audio_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
yield "audio_file", audio_out
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to add narration: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,231 +0,0 @@
|
|||||||
"""VideoTextOverlayBlock - Add text overlay to video."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from moviepy import CompositeVideoClip, TextClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoTextOverlayBlock(Block):
|
|
||||||
"""Add text overlay/caption to video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Input video (URL, data URI, or local path)"
|
|
||||||
)
|
|
||||||
text: str = SchemaField(description="Text to overlay on video")
|
|
||||||
position: Literal[
|
|
||||||
"top",
|
|
||||||
"center",
|
|
||||||
"bottom",
|
|
||||||
"top-left",
|
|
||||||
"top-right",
|
|
||||||
"bottom-left",
|
|
||||||
"bottom-right",
|
|
||||||
] = SchemaField(description="Position of text on screen", default="bottom")
|
|
||||||
start_time: float | None = SchemaField(
|
|
||||||
description="When to show text (seconds). None = entire video",
|
|
||||||
default=None,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
end_time: float | None = SchemaField(
|
|
||||||
description="When to hide text (seconds). None = until end",
|
|
||||||
default=None,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
font_size: int = SchemaField(
|
|
||||||
description="Font size", default=48, ge=12, le=200, advanced=True
|
|
||||||
)
|
|
||||||
font_color: str = SchemaField(
|
|
||||||
description="Font color (hex or name)", default="white", advanced=True
|
|
||||||
)
|
|
||||||
bg_color: str | None = SchemaField(
|
|
||||||
description="Background color behind text (None for transparent)",
|
|
||||||
default=None,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Video with text overlay (path or data URI)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8ef14de6-cc90-430a-8cfa-3a003be92454",
|
|
||||||
description="Add text overlay/caption to video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
disabled=True, # Disable until we can lockdown imagemagick security policy
|
|
||||||
test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
|
|
||||||
test_output=[("video_out", str)],
|
|
||||||
test_mock={
|
|
||||||
"_add_text_overlay": lambda *args: None,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_text_overlay(
|
|
||||||
self,
|
|
||||||
video_abspath: str,
|
|
||||||
output_abspath: str,
|
|
||||||
text: str,
|
|
||||||
position: str,
|
|
||||||
start_time: float | None,
|
|
||||||
end_time: float | None,
|
|
||||||
font_size: int,
|
|
||||||
font_color: str,
|
|
||||||
bg_color: str | None,
|
|
||||||
) -> None:
|
|
||||||
"""Add text overlay to video. Extracted for testability."""
|
|
||||||
video = None
|
|
||||||
final = None
|
|
||||||
txt_clip = None
|
|
||||||
try:
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
video = VideoFileClip(video_abspath)
|
|
||||||
|
|
||||||
txt_clip = TextClip(
|
|
||||||
text=text,
|
|
||||||
font_size=font_size,
|
|
||||||
color=font_color,
|
|
||||||
bg_color=bg_color,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Position mapping
|
|
||||||
pos_map = {
|
|
||||||
"top": ("center", "top"),
|
|
||||||
"center": ("center", "center"),
|
|
||||||
"bottom": ("center", "bottom"),
|
|
||||||
"top-left": ("left", "top"),
|
|
||||||
"top-right": ("right", "top"),
|
|
||||||
"bottom-left": ("left", "bottom"),
|
|
||||||
"bottom-right": ("right", "bottom"),
|
|
||||||
}
|
|
||||||
|
|
||||||
txt_clip = txt_clip.with_position(pos_map[position])
|
|
||||||
|
|
||||||
# Set timing
|
|
||||||
start = start_time or 0
|
|
||||||
end = end_time or video.duration
|
|
||||||
duration = max(0, end - start)
|
|
||||||
txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)
|
|
||||||
|
|
||||||
final = CompositeVideoClip([video, txt_clip])
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
final.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if txt_clip:
|
|
||||||
txt_clip.close()
|
|
||||||
if final:
|
|
||||||
final.close()
|
|
||||||
if video:
|
|
||||||
video.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# Validate time range if both are provided
|
|
||||||
if (
|
|
||||||
input_data.start_time is not None
|
|
||||||
and input_data.end_time is not None
|
|
||||||
and input_data.end_time <= input_data.start_time
|
|
||||||
):
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store the input video locally
|
|
||||||
local_video_path = await self._store_input_video(
|
|
||||||
execution_context, input_data.video_in
|
|
||||||
)
|
|
||||||
video_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_video_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build output path
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
self._add_text_overlay(
|
|
||||||
video_abspath,
|
|
||||||
output_abspath,
|
|
||||||
input_data.text,
|
|
||||||
input_data.position,
|
|
||||||
input_data.start_time,
|
|
||||||
input_data.end_time,
|
|
||||||
input_data.font_size,
|
|
||||||
input_data.font_color,
|
|
||||||
input_data.bg_color,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
|
|
||||||
except BlockExecutionError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to add text overlay: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -165,13 +165,10 @@ class TranscribeYoutubeVideoBlock(Block):
|
|||||||
credentials: WebshareProxyCredentials,
|
credentials: WebshareProxyCredentials,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
video_id = self.extract_video_id(input_data.youtube_url)
|
||||||
video_id = self.extract_video_id(input_data.youtube_url)
|
yield "video_id", video_id
|
||||||
transcript = self.get_transcript(video_id, credentials)
|
|
||||||
transcript_text = self.format_transcript(transcript=transcript)
|
|
||||||
|
|
||||||
# Only yield after all operations succeed
|
transcript = self.get_transcript(video_id, credentials)
|
||||||
yield "video_id", video_id
|
transcript_text = self.format_transcript(transcript=transcript)
|
||||||
yield "transcript", transcript_text
|
|
||||||
except Exception as e:
|
yield "transcript", transcript_text
|
||||||
yield "error", str(e)
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest_asyncio
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from backend.util.logging import configure_logging
|
from backend.util.logging import configure_logging
|
||||||
@@ -19,7 +19,7 @@ if not os.getenv("PRISMA_DEBUG"):
|
|||||||
prisma_logger.setLevel(logging.INFO)
|
prisma_logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||||
async def server():
|
async def server():
|
||||||
from backend.util.test import SpinTestServer
|
from backend.util.test import SpinTestServer
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ async def server():
|
|||||||
yield server
|
yield server
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
|
||||||
async def graph_cleanup(server):
|
async def graph_cleanup(server):
|
||||||
created_graph_ids = []
|
created_graph_ids = []
|
||||||
original_create_graph = server.agent_server.test_create_graph
|
original_create_graph = server.agent_server.test_create_graph
|
||||||
|
|||||||
@@ -246,9 +246,7 @@ class BlockSchema(BaseModel):
|
|||||||
f"is not of type {CredentialsMetaInput.__name__}"
|
f"is not of type {CredentialsMetaInput.__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
CredentialsMetaInput.validate_credentials_field_schema(
|
credentials_fields[field_name].validate_credentials_field_schema(cls)
|
||||||
cls.get_field_schema(field_name), field_name
|
|
||||||
)
|
|
||||||
|
|
||||||
elif field_name in credentials_fields:
|
elif field_name in credentials_fields:
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
|
|||||||
@@ -36,14 +36,12 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
|||||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||||
from backend.blocks.video.narration import VideoNarrationBlock
|
|
||||||
from backend.data.block import Block, BlockCost, BlockCostType
|
from backend.data.block import Block, BlockCost, BlockCostType
|
||||||
from backend.integrations.credentials_store import (
|
from backend.integrations.credentials_store import (
|
||||||
aiml_api_credentials,
|
aiml_api_credentials,
|
||||||
anthropic_credentials,
|
anthropic_credentials,
|
||||||
apollo_credentials,
|
apollo_credentials,
|
||||||
did_credentials,
|
did_credentials,
|
||||||
elevenlabs_credentials,
|
|
||||||
enrichlayer_credentials,
|
enrichlayer_credentials,
|
||||||
groq_credentials,
|
groq_credentials,
|
||||||
ideogram_credentials,
|
ideogram_credentials,
|
||||||
@@ -80,7 +78,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
|||||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||||
LlmModel.CLAUDE_4_OPUS: 21,
|
LlmModel.CLAUDE_4_OPUS: 21,
|
||||||
LlmModel.CLAUDE_4_SONNET: 5,
|
LlmModel.CLAUDE_4_SONNET: 5,
|
||||||
LlmModel.CLAUDE_4_6_OPUS: 14,
|
|
||||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||||
@@ -642,16 +639,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
VideoNarrationBlock: [
|
|
||||||
BlockCost(
|
|
||||||
cost_amount=5, # ElevenLabs TTS cost
|
|
||||||
cost_filter={
|
|
||||||
"credentials": {
|
|
||||||
"id": elevenlabs_credentials.id,
|
|
||||||
"provider": elevenlabs_credentials.provider,
|
|
||||||
"type": elevenlabs_credentials.type,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,16 +134,6 @@ async def test_block_credit_reset(server: SpinTestServer):
|
|||||||
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
||||||
user_credit.time_now = lambda: month1
|
user_credit.time_now = lambda: month1
|
||||||
|
|
||||||
# IMPORTANT: Set updatedAt to December of previous year to ensure it's
|
|
||||||
# in a different month than month1 (January). This fixes a timing bug
|
|
||||||
# where if the test runs in early February, 35 days ago would be January,
|
|
||||||
# matching the mocked month1 and preventing the refill from triggering.
|
|
||||||
dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
|
|
||||||
await UserBalance.prisma().update(
|
|
||||||
where={"userId": DEFAULT_USER_ID},
|
|
||||||
data={"updatedAt": dec_previous_year},
|
|
||||||
)
|
|
||||||
|
|
||||||
# First call in month 1 should trigger refill
|
# First call in month 1 should trigger refill
|
||||||
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
assert balance == REFILL_VALUE # Should get 1000 credits
|
assert balance == REFILL_VALUE # Should get 1000 credits
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import queue
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from multiprocessing import Manager
|
||||||
|
from queue import Empty
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Annotated,
|
Annotated,
|
||||||
@@ -1199,16 +1200,12 @@ class NodeExecutionEntry(BaseModel):
|
|||||||
|
|
||||||
class ExecutionQueue(Generic[T]):
|
class ExecutionQueue(Generic[T]):
|
||||||
"""
|
"""
|
||||||
Thread-safe queue for managing node execution within a single graph execution.
|
Queue for managing the execution of agents.
|
||||||
|
This will be shared between different processes
|
||||||
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
|
|
||||||
threads within the same process. If migrating back to ProcessPoolExecutor,
|
|
||||||
replace with multiprocessing.Manager().Queue() for cross-process safety.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Thread-safe queue (not multiprocessing) — see class docstring
|
self.queue = Manager().Queue()
|
||||||
self.queue: queue.Queue[T] = queue.Queue()
|
|
||||||
|
|
||||||
def add(self, execution: T) -> T:
|
def add(self, execution: T) -> T:
|
||||||
self.queue.put(execution)
|
self.queue.put(execution)
|
||||||
@@ -1223,7 +1220,7 @@ class ExecutionQueue(Generic[T]):
|
|||||||
def get_or_none(self) -> T | None:
|
def get_or_none(self) -> T | None:
|
||||||
try:
|
try:
|
||||||
return self.queue.get_nowait()
|
return self.queue.get_nowait()
|
||||||
except queue.Empty:
|
except Empty:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
"""Tests for ExecutionQueue thread-safety."""
|
|
||||||
|
|
||||||
import queue
|
|
||||||
import threading
|
|
||||||
|
|
||||||
from backend.data.execution import ExecutionQueue
|
|
||||||
|
|
||||||
|
|
||||||
def test_execution_queue_uses_stdlib_queue():
|
|
||||||
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
|
|
||||||
q = ExecutionQueue()
|
|
||||||
assert isinstance(q.queue, queue.Queue)
|
|
||||||
|
|
||||||
|
|
||||||
def test_basic_operations():
|
|
||||||
"""Test add, get, empty, and get_or_none."""
|
|
||||||
q = ExecutionQueue()
|
|
||||||
|
|
||||||
assert q.empty() is True
|
|
||||||
assert q.get_or_none() is None
|
|
||||||
|
|
||||||
result = q.add("item1")
|
|
||||||
assert result == "item1"
|
|
||||||
assert q.empty() is False
|
|
||||||
|
|
||||||
item = q.get()
|
|
||||||
assert item == "item1"
|
|
||||||
assert q.empty() is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_thread_safety():
|
|
||||||
"""Test concurrent access from multiple threads."""
|
|
||||||
q = ExecutionQueue()
|
|
||||||
results = []
|
|
||||||
num_items = 100
|
|
||||||
|
|
||||||
def producer():
|
|
||||||
for i in range(num_items):
|
|
||||||
q.add(f"item_{i}")
|
|
||||||
|
|
||||||
def consumer():
|
|
||||||
count = 0
|
|
||||||
while count < num_items:
|
|
||||||
item = q.get_or_none()
|
|
||||||
if item is not None:
|
|
||||||
results.append(item)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
producer_thread = threading.Thread(target=producer)
|
|
||||||
consumer_thread = threading.Thread(target=consumer)
|
|
||||||
|
|
||||||
producer_thread.start()
|
|
||||||
consumer_thread.start()
|
|
||||||
|
|
||||||
producer_thread.join(timeout=5)
|
|
||||||
consumer_thread.join(timeout=5)
|
|
||||||
|
|
||||||
assert len(results) == num_items
|
|
||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Self, cast
|
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
|
||||||
|
|
||||||
from prisma.enums import SubmissionStatus
|
from prisma.enums import SubmissionStatus
|
||||||
from prisma.models import (
|
from prisma.models import (
|
||||||
@@ -20,7 +20,7 @@ from prisma.types import (
|
|||||||
AgentNodeLinkCreateInput,
|
AgentNodeLinkCreateInput,
|
||||||
StoreListingVersionWhereInput,
|
StoreListingVersionWhereInput,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, BeforeValidator, Field
|
from pydantic import BaseModel, BeforeValidator, Field, create_model
|
||||||
from pydantic.fields import computed_field
|
from pydantic.fields import computed_field
|
||||||
|
|
||||||
from backend.blocks.agent import AgentExecutorBlock
|
from backend.blocks.agent import AgentExecutorBlock
|
||||||
@@ -30,6 +30,7 @@ from backend.data.db import prisma as db
|
|||||||
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
|
CredentialsField,
|
||||||
CredentialsFieldInfo,
|
CredentialsFieldInfo,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
is_credentials_field_name,
|
is_credentials_field_name,
|
||||||
@@ -39,12 +40,12 @@ from backend.util import type as type_utils
|
|||||||
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
|
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
from backend.util.request import parse_url
|
|
||||||
|
|
||||||
from .block import (
|
from .block import (
|
||||||
AnyBlockSchema,
|
AnyBlockSchema,
|
||||||
Block,
|
Block,
|
||||||
BlockInput,
|
BlockInput,
|
||||||
|
BlockSchema,
|
||||||
BlockType,
|
BlockType,
|
||||||
EmptySchema,
|
EmptySchema,
|
||||||
get_block,
|
get_block,
|
||||||
@@ -112,12 +113,10 @@ class Link(BaseDbModel):
|
|||||||
|
|
||||||
class Node(BaseDbModel):
|
class Node(BaseDbModel):
|
||||||
block_id: str
|
block_id: str
|
||||||
input_default: BlockInput = Field( # dict[input_name, default_value]
|
input_default: BlockInput = {} # dict[input_name, default_value]
|
||||||
default_factory=dict
|
metadata: dict[str, Any] = {}
|
||||||
)
|
input_links: list[Link] = []
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
output_links: list[Link] = []
|
||||||
input_links: list[Link] = Field(default_factory=list)
|
|
||||||
output_links: list[Link] = Field(default_factory=list)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def credentials_optional(self) -> bool:
|
def credentials_optional(self) -> bool:
|
||||||
@@ -222,33 +221,18 @@ class NodeModel(Node):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class GraphBaseMeta(BaseDbModel):
|
class BaseGraph(BaseDbModel):
|
||||||
"""
|
|
||||||
Shared base for `GraphMeta` and `BaseGraph`, with core graph metadata fields.
|
|
||||||
"""
|
|
||||||
|
|
||||||
version: int = 1
|
version: int = 1
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
instructions: str | None = None
|
instructions: str | None = None
|
||||||
recommended_schedule_cron: str | None = None
|
recommended_schedule_cron: str | None = None
|
||||||
|
nodes: list[Node] = []
|
||||||
|
links: list[Link] = []
|
||||||
forked_from_id: str | None = None
|
forked_from_id: str | None = None
|
||||||
forked_from_version: int | None = None
|
forked_from_version: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseGraph(GraphBaseMeta):
|
|
||||||
"""
|
|
||||||
Graph with nodes, links, and computed I/O schema fields.
|
|
||||||
|
|
||||||
Used to represent sub-graphs within a `Graph`. Contains the full graph
|
|
||||||
structure including nodes and links, plus computed fields for schemas
|
|
||||||
and trigger info. Does NOT include user_id or created_at (see GraphModel).
|
|
||||||
"""
|
|
||||||
|
|
||||||
nodes: list[Node] = Field(default_factory=list)
|
|
||||||
links: list[Link] = Field(default_factory=list)
|
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def input_schema(self) -> dict[str, Any]:
|
def input_schema(self) -> dict[str, Any]:
|
||||||
@@ -377,79 +361,44 @@ class GraphTriggerInfo(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Graph(BaseGraph):
|
class Graph(BaseGraph):
|
||||||
"""Creatable graph model used in API create/update endpoints."""
|
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
||||||
|
|
||||||
sub_graphs: list[BaseGraph] = Field(default_factory=list) # Flattened sub-graphs
|
|
||||||
|
|
||||||
|
|
||||||
class GraphMeta(GraphBaseMeta):
|
|
||||||
"""
|
|
||||||
Lightweight graph metadata model representing an existing graph from the database,
|
|
||||||
for use in listings and summaries.
|
|
||||||
|
|
||||||
Lacks `GraphModel`'s nodes, links, and expensive computed fields.
|
|
||||||
Use for list endpoints where full graph data is not needed and performance matters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: str # type: ignore
|
|
||||||
version: int # type: ignore
|
|
||||||
user_id: str
|
|
||||||
created_at: datetime
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_db(cls, graph: "AgentGraph") -> Self:
|
|
||||||
return cls(
|
|
||||||
id=graph.id,
|
|
||||||
version=graph.version,
|
|
||||||
is_active=graph.isActive,
|
|
||||||
name=graph.name or "",
|
|
||||||
description=graph.description or "",
|
|
||||||
instructions=graph.instructions,
|
|
||||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
|
||||||
forked_from_id=graph.forkedFromId,
|
|
||||||
forked_from_version=graph.forkedFromVersion,
|
|
||||||
user_id=graph.userId,
|
|
||||||
created_at=graph.createdAt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphModel(Graph, GraphMeta):
|
|
||||||
"""
|
|
||||||
Full graph model representing an existing graph from the database.
|
|
||||||
|
|
||||||
This is the primary model for working with persisted graphs. Includes all
|
|
||||||
graph data (nodes, links, sub_graphs) plus user ownership and timestamps.
|
|
||||||
Provides computed fields (input_schema, output_schema, etc.) used during
|
|
||||||
set-up (frontend) and execution (backend).
|
|
||||||
|
|
||||||
Inherits from:
|
|
||||||
- `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas
|
|
||||||
- `GraphMeta`: provides user_id, created_at for database records
|
|
||||||
"""
|
|
||||||
|
|
||||||
nodes: list[NodeModel] = Field(default_factory=list) # type: ignore
|
|
||||||
|
|
||||||
@property
|
|
||||||
def starting_nodes(self) -> list[NodeModel]:
|
|
||||||
outbound_nodes = {link.sink_id for link in self.links}
|
|
||||||
input_nodes = {
|
|
||||||
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
|
||||||
}
|
|
||||||
return [
|
|
||||||
node
|
|
||||||
for node in self.nodes
|
|
||||||
if node.id not in outbound_nodes or node.id in input_nodes
|
|
||||||
]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
|
||||||
return cast(NodeModel, super().webhook_input_node)
|
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def credentials_input_schema(self) -> dict[str, Any]:
|
def credentials_input_schema(self) -> dict[str, Any]:
|
||||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
schema = self._credentials_input_schema.jsonschema()
|
||||||
|
|
||||||
|
# Determine which credential fields are required based on credentials_optional metadata
|
||||||
|
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||||
|
required_fields = []
|
||||||
|
|
||||||
|
# Build a map of node_id -> node for quick lookup
|
||||||
|
all_nodes = {node.id: node for node in self.nodes}
|
||||||
|
for sub_graph in self.sub_graphs:
|
||||||
|
for node in sub_graph.nodes:
|
||||||
|
all_nodes[node.id] = node
|
||||||
|
|
||||||
|
for field_key, (
|
||||||
|
_field_info,
|
||||||
|
node_field_pairs,
|
||||||
|
) in graph_credentials_inputs.items():
|
||||||
|
# A field is required if ANY node using it has credentials_optional=False
|
||||||
|
is_required = False
|
||||||
|
for node_id, _field_name in node_field_pairs:
|
||||||
|
node = all_nodes.get(node_id)
|
||||||
|
if node and not node.credentials_optional:
|
||||||
|
is_required = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if is_required:
|
||||||
|
required_fields.append(field_key)
|
||||||
|
|
||||||
|
schema["required"] = required_fields
|
||||||
|
return schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||||
|
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||||
f"{graph_credentials_inputs}"
|
f"{graph_credentials_inputs}"
|
||||||
@@ -457,15 +406,12 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
|
|
||||||
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
||||||
graph_cred_fields = list(graph_credentials_inputs.values())
|
graph_cred_fields = list(graph_credentials_inputs.values())
|
||||||
for i, (field, keys, _) in enumerate(graph_cred_fields):
|
for i, (field, keys) in enumerate(graph_cred_fields):
|
||||||
for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]:
|
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||||
if field.provider != other_field.provider:
|
if field.provider != other_field.provider:
|
||||||
continue
|
continue
|
||||||
if ProviderName.HTTP in field.provider:
|
if ProviderName.HTTP in field.provider:
|
||||||
continue
|
continue
|
||||||
# MCP credentials are intentionally split by server URL
|
|
||||||
if ProviderName.MCP in field.provider:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# If this happens, that means a block implementation probably needs
|
# If this happens, that means a block implementation probably needs
|
||||||
# to be updated.
|
# to be updated.
|
||||||
@@ -477,90 +423,31 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
f"keys: {keys} <> {other_keys}."
|
f"keys: {keys} <> {other_keys}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build JSON schema directly to avoid expensive create_model + validation overhead
|
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
|
||||||
properties = {}
|
agg_field_key: (
|
||||||
required_fields = []
|
CredentialsMetaInput[
|
||||||
|
Literal[tuple(field_info.provider)], # type: ignore
|
||||||
for agg_field_key, (
|
Literal[tuple(field_info.supported_types)], # type: ignore
|
||||||
field_info,
|
],
|
||||||
_,
|
CredentialsField(
|
||||||
is_required,
|
required_scopes=set(field_info.required_scopes or []),
|
||||||
) in graph_credentials_inputs.items():
|
discriminator=field_info.discriminator,
|
||||||
providers = list(field_info.provider)
|
discriminator_mapping=field_info.discriminator_mapping,
|
||||||
cred_types = list(field_info.supported_types)
|
discriminator_values=field_info.discriminator_values,
|
||||||
|
),
|
||||||
field_schema: dict[str, Any] = {
|
|
||||||
"credentials_provider": providers,
|
|
||||||
"credentials_types": cred_types,
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"id": {"title": "Id", "type": "string"},
|
|
||||||
"title": {
|
|
||||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
|
||||||
"default": None,
|
|
||||||
"title": "Title",
|
|
||||||
},
|
|
||||||
"provider": {
|
|
||||||
"title": "Provider",
|
|
||||||
"type": "string",
|
|
||||||
**(
|
|
||||||
{"enum": providers}
|
|
||||||
if len(providers) > 1
|
|
||||||
else {"const": providers[0]}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"title": "Type",
|
|
||||||
"type": "string",
|
|
||||||
**(
|
|
||||||
{"enum": cred_types}
|
|
||||||
if len(cred_types) > 1
|
|
||||||
else {"const": cred_types[0]}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["id", "provider", "type"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add a descriptive display title when URL-based discriminator values
|
|
||||||
# are present (e.g. "mcp.sentry.dev" instead of just "Mcp")
|
|
||||||
if (
|
|
||||||
field_info.discriminator
|
|
||||||
and not field_info.discriminator_mapping
|
|
||||||
and field_info.discriminator_values
|
|
||||||
):
|
|
||||||
hostnames = sorted(
|
|
||||||
parse_url(str(v)).netloc for v in field_info.discriminator_values
|
|
||||||
)
|
|
||||||
field_schema["display_name"] = ", ".join(hostnames)
|
|
||||||
|
|
||||||
# Add other (optional) field info items
|
|
||||||
field_schema.update(
|
|
||||||
field_info.model_dump(
|
|
||||||
by_alias=True,
|
|
||||||
exclude_defaults=True,
|
|
||||||
exclude={"provider", "supported_types"}, # already included above
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||||
# Ensure field schema is well-formed
|
|
||||||
CredentialsMetaInput.validate_credentials_field_schema(
|
|
||||||
field_schema, agg_field_key
|
|
||||||
)
|
|
||||||
|
|
||||||
properties[agg_field_key] = field_schema
|
|
||||||
if is_required:
|
|
||||||
required_fields.append(agg_field_key)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": properties,
|
|
||||||
"required": required_fields,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return create_model(
|
||||||
|
self.name.replace(" ", "") + "CredentialsInputSchema",
|
||||||
|
__base__=BlockSchema,
|
||||||
|
**fields, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
def aggregate_credentials_inputs(
|
def aggregate_credentials_inputs(
|
||||||
self,
|
self,
|
||||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
dict[aggregated_field_key, tuple(
|
dict[aggregated_field_key, tuple(
|
||||||
@@ -568,28 +455,13 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
(now includes discriminator_values from matching nodes)
|
(now includes discriminator_values from matching nodes)
|
||||||
set[(node_id, field_name)]: Node credentials fields that are
|
set[(node_id, field_name)]: Node credentials fields that are
|
||||||
compatible with this aggregated field spec
|
compatible with this aggregated field spec
|
||||||
bool: True if the field is required (any node has credentials_optional=False)
|
|
||||||
)]
|
)]
|
||||||
"""
|
"""
|
||||||
# First collect all credential field data with input defaults
|
# First collect all credential field data with input defaults
|
||||||
# Track (field_info, (node_id, field_name), is_required) for each credential field
|
node_credential_data = []
|
||||||
node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
|
|
||||||
node_required_map: dict[str, bool] = {} # node_id -> is_required
|
|
||||||
|
|
||||||
for graph in [self] + self.sub_graphs:
|
for graph in [self] + self.sub_graphs:
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
# A node's credentials are optional if either:
|
|
||||||
# 1. The node metadata says so (credentials_optional=True), or
|
|
||||||
# 2. All credential fields on the block have defaults (not required by schema)
|
|
||||||
block_required = node.block.input_schema.get_required_fields()
|
|
||||||
creds_required_by_schema = any(
|
|
||||||
fname in block_required
|
|
||||||
for fname in node.block.input_schema.get_credentials_fields()
|
|
||||||
)
|
|
||||||
node_required_map[node.id] = (
|
|
||||||
not node.credentials_optional and creds_required_by_schema
|
|
||||||
)
|
|
||||||
|
|
||||||
for (
|
for (
|
||||||
field_name,
|
field_name,
|
||||||
field_info,
|
field_info,
|
||||||
@@ -613,21 +485,37 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Combine credential field info (this will merge discriminator_values automatically)
|
# Combine credential field info (this will merge discriminator_values automatically)
|
||||||
combined = CredentialsFieldInfo.combine(*node_credential_data)
|
return CredentialsFieldInfo.combine(*node_credential_data)
|
||||||
|
|
||||||
# Add is_required flag to each aggregated field
|
|
||||||
# A field is required if ANY node using it has credentials_optional=False
|
class GraphModel(Graph):
|
||||||
return {
|
user_id: str
|
||||||
key: (
|
nodes: list[NodeModel] = [] # type: ignore
|
||||||
field_info,
|
|
||||||
node_field_pairs,
|
created_at: datetime
|
||||||
any(
|
|
||||||
node_required_map.get(node_id, True)
|
@property
|
||||||
for node_id, _ in node_field_pairs
|
def starting_nodes(self) -> list[NodeModel]:
|
||||||
),
|
outbound_nodes = {link.sink_id for link in self.links}
|
||||||
)
|
input_nodes = {
|
||||||
for key, (field_info, node_field_pairs) in combined.items()
|
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
||||||
}
|
}
|
||||||
|
return [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.id not in outbound_nodes or node.id in input_nodes
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
||||||
|
return cast(NodeModel, super().webhook_input_node)
|
||||||
|
|
||||||
|
def meta(self) -> "GraphMeta":
|
||||||
|
"""
|
||||||
|
Returns a GraphMeta object with metadata about the graph.
|
||||||
|
This is used to return metadata about the graph without exposing nodes and links.
|
||||||
|
"""
|
||||||
|
return GraphMeta.from_graph(self)
|
||||||
|
|
||||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||||
"""
|
"""
|
||||||
@@ -911,14 +799,13 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
if is_static_output_block(link.source_id):
|
if is_static_output_block(link.source_id):
|
||||||
link.is_static = True # Each value block output should be static.
|
link.is_static = True # Each value block output should be static.
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def from_db( # type: ignore[reportIncompatibleMethodOverride]
|
def from_db(
|
||||||
cls,
|
|
||||||
graph: AgentGraph,
|
graph: AgentGraph,
|
||||||
for_export: bool = False,
|
for_export: bool = False,
|
||||||
sub_graphs: list[AgentGraph] | None = None,
|
sub_graphs: list[AgentGraph] | None = None,
|
||||||
) -> Self:
|
) -> "GraphModel":
|
||||||
return cls(
|
return GraphModel(
|
||||||
id=graph.id,
|
id=graph.id,
|
||||||
user_id=graph.userId if not for_export else "",
|
user_id=graph.userId if not for_export else "",
|
||||||
version=graph.version,
|
version=graph.version,
|
||||||
@@ -944,28 +831,17 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def hide_nodes(self) -> "GraphModelWithoutNodes":
|
|
||||||
"""
|
|
||||||
Returns a copy of the `GraphModel` with nodes, links, and sub-graphs hidden
|
|
||||||
(excluded from serialization). They are still present in the model instance
|
|
||||||
so all computed fields (e.g. `credentials_input_schema`) still work.
|
|
||||||
"""
|
|
||||||
return GraphModelWithoutNodes.model_validate(self, from_attributes=True)
|
|
||||||
|
|
||||||
|
class GraphMeta(Graph):
|
||||||
|
user_id: str
|
||||||
|
|
||||||
class GraphModelWithoutNodes(GraphModel):
|
# Easy work-around to prevent exposing nodes and links in the API response
|
||||||
"""
|
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
|
||||||
GraphModel variant that excludes nodes, links, and sub-graphs from serialization.
|
links: list[Link] = Field(default=[], exclude=True)
|
||||||
|
|
||||||
Used in contexts like the store where exposing internal graph structure
|
@staticmethod
|
||||||
is not desired. Inherits all computed fields from GraphModel but marks
|
def from_graph(graph: GraphModel) -> "GraphMeta":
|
||||||
nodes and links as excluded from JSON output.
|
return GraphMeta(**graph.model_dump())
|
||||||
"""
|
|
||||||
|
|
||||||
nodes: list[NodeModel] = Field(default_factory=list, exclude=True)
|
|
||||||
links: list[Link] = Field(default_factory=list, exclude=True)
|
|
||||||
|
|
||||||
sub_graphs: list[BaseGraph] = Field(default_factory=list, exclude=True)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphsPaginated(BaseModel):
|
class GraphsPaginated(BaseModel):
|
||||||
@@ -1036,11 +912,21 @@ async def list_graphs_paginated(
|
|||||||
where=where_clause,
|
where=where_clause,
|
||||||
distinct=["id"],
|
distinct=["id"],
|
||||||
order={"version": "desc"},
|
order={"version": "desc"},
|
||||||
|
include=AGENT_GRAPH_INCLUDE,
|
||||||
skip=offset,
|
skip=offset,
|
||||||
take=page_size,
|
take=page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_models = [GraphMeta.from_db(graph) for graph in graphs]
|
graph_models: list[GraphMeta] = []
|
||||||
|
for graph in graphs:
|
||||||
|
try:
|
||||||
|
graph_meta = GraphModel.from_db(graph).meta()
|
||||||
|
# Trigger serialization to validate that the graph is well formed
|
||||||
|
graph_meta.model_dump()
|
||||||
|
graph_models.append(graph_meta)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing graph {graph.id}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
return GraphsPaginated(
|
return GraphsPaginated(
|
||||||
graphs=graph_models,
|
graphs=graph_models,
|
||||||
|
|||||||
@@ -463,120 +463,3 @@ def test_node_credentials_optional_with_other_metadata():
|
|||||||
assert node.credentials_optional is True
|
assert node.credentials_optional is True
|
||||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||||
assert node.metadata["customized_name"] == "My Custom Node"
|
assert node.metadata["customized_name"] == "My Custom Node"
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Tests for MCP Credential Deduplication
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def test_mcp_credential_combine_different_servers():
|
|
||||||
"""Two MCP credential fields with different server URLs should produce
|
|
||||||
separate entries when combined (not merged into one)."""
|
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
|
||||||
|
|
||||||
field_sentry = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
|
||||||
)
|
|
||||||
field_linear = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
discriminator_values={"https://mcp.linear.app/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
combined = CredentialsFieldInfo.combine(
|
|
||||||
(field_sentry, ("node-sentry", "credentials")),
|
|
||||||
(field_linear, ("node-linear", "credentials")),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should produce 2 separate credential entries
|
|
||||||
assert len(combined) == 2, (
|
|
||||||
f"Expected 2 credential entries for 2 MCP blocks with different servers, "
|
|
||||||
f"got {len(combined)}: {list(combined.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Each entry should contain the server hostname in its key
|
|
||||||
keys = list(combined.keys())
|
|
||||||
assert any(
|
|
||||||
"mcp.sentry.dev" in k for k in keys
|
|
||||||
), f"Expected 'mcp.sentry.dev' in one key, got {keys}"
|
|
||||||
assert any(
|
|
||||||
"mcp.linear.app" in k for k in keys
|
|
||||||
), f"Expected 'mcp.linear.app' in one key, got {keys}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_mcp_credential_combine_same_server():
|
|
||||||
"""Two MCP credential fields with the same server URL should be combined
|
|
||||||
into one credential entry."""
|
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
|
||||||
|
|
||||||
field_a = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
|
||||||
)
|
|
||||||
field_b = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
combined = CredentialsFieldInfo.combine(
|
|
||||||
(field_a, ("node-a", "credentials")),
|
|
||||||
(field_b, ("node-b", "credentials")),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should produce 1 credential entry (same server URL)
|
|
||||||
assert len(combined) == 1, (
|
|
||||||
f"Expected 1 credential entry for 2 MCP blocks with same server, "
|
|
||||||
f"got {len(combined)}: {list(combined.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mcp_credential_combine_no_discriminator_values():
|
|
||||||
"""MCP credential fields without discriminator_values should be merged
|
|
||||||
into a single entry (backwards compat for blocks without server_url set)."""
|
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
|
||||||
|
|
||||||
field_a = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
)
|
|
||||||
field_b = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
)
|
|
||||||
|
|
||||||
combined = CredentialsFieldInfo.combine(
|
|
||||||
(field_a, ("node-a", "credentials")),
|
|
||||||
(field_b, ("node-b", "credentials")),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should produce 1 entry (no URL differentiation)
|
|
||||||
assert len(combined) == 1, (
|
|
||||||
f"Expected 1 credential entry for MCP blocks without discriminator_values, "
|
|
||||||
f"got {len(combined)}: {list(combined.keys())}"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from pydantic import (
|
|||||||
GetCoreSchemaHandler,
|
GetCoreSchemaHandler,
|
||||||
SecretStr,
|
SecretStr,
|
||||||
field_serializer,
|
field_serializer,
|
||||||
model_validator,
|
|
||||||
)
|
)
|
||||||
from pydantic_core import (
|
from pydantic_core import (
|
||||||
CoreSchema,
|
CoreSchema,
|
||||||
@@ -164,6 +163,7 @@ class User(BaseModel):
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from prisma.models import User as PrismaUser
|
from prisma.models import User as PrismaUser
|
||||||
|
|
||||||
|
from backend.data.block import BlockSchema
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -500,25 +500,6 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
provider: CP
|
provider: CP
|
||||||
type: CT
|
type: CT
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _normalize_legacy_provider(cls, data: Any) -> Any:
|
|
||||||
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.
|
|
||||||
|
|
||||||
Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"``
|
|
||||||
instead of the plain value. Old stored credential references may have
|
|
||||||
``provider: "ProviderName.MCP"`` instead of ``"mcp"``.
|
|
||||||
"""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
prov = data.get("provider", "")
|
|
||||||
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
|
||||||
member = prov.removeprefix("ProviderName.")
|
|
||||||
try:
|
|
||||||
data = {**data, "provider": ProviderName[member].value}
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
|
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
|
||||||
return get_args(cls.model_fields["provider"].annotation)
|
return get_args(cls.model_fields["provider"].annotation)
|
||||||
@@ -527,13 +508,15 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
||||||
return get_args(cls.model_fields["type"].annotation)
|
return get_args(cls.model_fields["type"].annotation)
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def validate_credentials_field_schema(
|
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
|
||||||
field_schema: dict[str, Any], field_name: str
|
|
||||||
):
|
|
||||||
"""Validates the schema of a credentials input field"""
|
"""Validates the schema of a credentials input field"""
|
||||||
|
field_name = next(
|
||||||
|
name for name, type in model.get_credentials_fields().items() if type is cls
|
||||||
|
)
|
||||||
|
field_schema = model.jsonschema()["properties"][field_name]
|
||||||
try:
|
try:
|
||||||
field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if "Field required [type=missing" not in str(e):
|
if "Field required [type=missing" not in str(e):
|
||||||
raise
|
raise
|
||||||
@@ -543,11 +526,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
f"{field_schema}"
|
f"{field_schema}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
providers = field_info.provider
|
providers = cls.allowed_providers()
|
||||||
if (
|
if (
|
||||||
providers is not None
|
providers is not None
|
||||||
and len(providers) > 1
|
and len(providers) > 1
|
||||||
and not field_info.discriminator
|
and not schema_extra.discriminator
|
||||||
):
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Multi-provider CredentialsField '{field_name}' "
|
f"Multi-provider CredentialsField '{field_name}' "
|
||||||
@@ -623,18 +606,11 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
] = defaultdict(list)
|
] = defaultdict(list)
|
||||||
|
|
||||||
for field, key in fields:
|
for field, key in fields:
|
||||||
if (
|
if field.provider == frozenset([ProviderName.HTTP]):
|
||||||
field.discriminator
|
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
|
||||||
and not field.discriminator_mapping
|
# Group by host extracted from the URL
|
||||||
and field.discriminator_values
|
|
||||||
):
|
|
||||||
# URL-based discrimination (e.g. HTTP host-scoped, MCP server URL):
|
|
||||||
# Each unique host gets its own credential entry.
|
|
||||||
provider_prefix = next(iter(field.provider))
|
|
||||||
# Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP")
|
|
||||||
prefix_str = getattr(provider_prefix, "value", str(provider_prefix))
|
|
||||||
providers = frozenset(
|
providers = frozenset(
|
||||||
[cast(CP, prefix_str)]
|
[cast(CP, "http")]
|
||||||
+ [
|
+ [
|
||||||
cast(CP, parse_url(str(value)).netloc)
|
cast(CP, parse_url(str(value)).netloc)
|
||||||
for value in field.discriminator_values
|
for value in field.discriminator_values
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -226,10 +225,6 @@ class SyncRabbitMQ(RabbitMQBase):
|
|||||||
class AsyncRabbitMQ(RabbitMQBase):
|
class AsyncRabbitMQ(RabbitMQBase):
|
||||||
"""Asynchronous RabbitMQ client"""
|
"""Asynchronous RabbitMQ client"""
|
||||||
|
|
||||||
def __init__(self, config: RabbitMQConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
self._reconnect_lock: asyncio.Lock | None = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return bool(self._connection and not self._connection.is_closed)
|
return bool(self._connection and not self._connection.is_closed)
|
||||||
@@ -240,17 +235,7 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
|
|
||||||
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
if self.is_connected and self._channel and not self._channel.is_closed:
|
if self.is_connected:
|
||||||
return
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.is_connected
|
|
||||||
and self._connection
|
|
||||||
and (self._channel is None or self._channel.is_closed)
|
|
||||||
):
|
|
||||||
self._channel = await self._connection.channel()
|
|
||||||
await self._channel.set_qos(prefetch_count=1)
|
|
||||||
await self.declare_infrastructure()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self._connection = await aio_pika.connect_robust(
|
self._connection = await aio_pika.connect_robust(
|
||||||
@@ -306,46 +291,24 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
exchange, routing_key=queue.routing_key or queue.name
|
exchange, routing_key=queue.routing_key or queue.name
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@func_retry
|
||||||
def _lock(self) -> asyncio.Lock:
|
async def publish_message(
|
||||||
if self._reconnect_lock is None:
|
|
||||||
self._reconnect_lock = asyncio.Lock()
|
|
||||||
return self._reconnect_lock
|
|
||||||
|
|
||||||
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
|
|
||||||
"""Get a valid channel, reconnecting if the current one is stale.
|
|
||||||
|
|
||||||
Uses a lock to prevent concurrent reconnection attempts from racing.
|
|
||||||
"""
|
|
||||||
if self.is_ready:
|
|
||||||
return self._channel # type: ignore # is_ready guarantees non-None
|
|
||||||
|
|
||||||
async with self._lock:
|
|
||||||
# Double-check after acquiring lock
|
|
||||||
if self.is_ready:
|
|
||||||
return self._channel # type: ignore
|
|
||||||
|
|
||||||
self._channel = None
|
|
||||||
await self.connect()
|
|
||||||
|
|
||||||
if self._channel is None:
|
|
||||||
raise RuntimeError("Channel should be established after connect")
|
|
||||||
|
|
||||||
return self._channel
|
|
||||||
|
|
||||||
async def _publish_once(
|
|
||||||
self,
|
self,
|
||||||
routing_key: str,
|
routing_key: str,
|
||||||
message: str,
|
message: str,
|
||||||
exchange: Optional[Exchange] = None,
|
exchange: Optional[Exchange] = None,
|
||||||
persistent: bool = True,
|
persistent: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
channel = await self._ensure_channel()
|
if not self.is_ready:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
if self._channel is None:
|
||||||
|
raise RuntimeError("Channel should be established after connect")
|
||||||
|
|
||||||
if exchange:
|
if exchange:
|
||||||
exchange_obj = await channel.get_exchange(exchange.name)
|
exchange_obj = await self._channel.get_exchange(exchange.name)
|
||||||
else:
|
else:
|
||||||
exchange_obj = channel.default_exchange
|
exchange_obj = self._channel.default_exchange
|
||||||
|
|
||||||
await exchange_obj.publish(
|
await exchange_obj.publish(
|
||||||
aio_pika.Message(
|
aio_pika.Message(
|
||||||
@@ -359,23 +322,9 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
routing_key=routing_key,
|
routing_key=routing_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@func_retry
|
|
||||||
async def publish_message(
|
|
||||||
self,
|
|
||||||
routing_key: str,
|
|
||||||
message: str,
|
|
||||||
exchange: Optional[Exchange] = None,
|
|
||||||
persistent: bool = True,
|
|
||||||
) -> None:
|
|
||||||
try:
|
|
||||||
await self._publish_once(routing_key, message, exchange, persistent)
|
|
||||||
except aio_pika.exceptions.ChannelInvalidStateError:
|
|
||||||
logger.warning(
|
|
||||||
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
|
|
||||||
)
|
|
||||||
async with self._lock:
|
|
||||||
self._channel = None
|
|
||||||
await self._publish_once(routing_key, message, exchange, persistent)
|
|
||||||
|
|
||||||
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||||
return await self._ensure_channel()
|
if not self.is_ready:
|
||||||
|
await self.connect()
|
||||||
|
if self._channel is None:
|
||||||
|
raise RuntimeError("Channel should be established after connect")
|
||||||
|
return self._channel
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from redis.asyncio.lock import Lock as AsyncRedisLock
|
|||||||
|
|
||||||
from backend.blocks.agent import AgentExecutorBlock
|
from backend.blocks.agent import AgentExecutorBlock
|
||||||
from backend.blocks.io import AgentOutputBlock
|
from backend.blocks.io import AgentOutputBlock
|
||||||
from backend.blocks.mcp.block import MCPToolBlock
|
|
||||||
from backend.data import redis_client as redis
|
from backend.data import redis_client as redis
|
||||||
from backend.data.block import (
|
from backend.data.block import (
|
||||||
BlockInput,
|
BlockInput,
|
||||||
@@ -230,10 +229,6 @@ async def execute_node(
|
|||||||
_input_data.nodes_input_masks = nodes_input_masks
|
_input_data.nodes_input_masks = nodes_input_masks
|
||||||
_input_data.user_id = user_id
|
_input_data.user_id = user_id
|
||||||
input_data = _input_data.model_dump()
|
input_data = _input_data.model_dump()
|
||||||
elif isinstance(node_block, MCPToolBlock):
|
|
||||||
_mcp_data = MCPToolBlock.Input(**node.input_default)
|
|
||||||
_mcp_data.tool_arguments = input_data
|
|
||||||
input_data = _mcp_data.model_dump()
|
|
||||||
data.inputs = input_data
|
data.inputs = input_data
|
||||||
|
|
||||||
# Execute the node
|
# Execute the node
|
||||||
@@ -270,34 +265,8 @@ async def execute_node(
|
|||||||
|
|
||||||
# Handle regular credentials fields
|
# Handle regular credentials fields
|
||||||
for field_name, input_type in input_model.get_credentials_fields().items():
|
for field_name, input_type in input_model.get_credentials_fields().items():
|
||||||
field_value = input_data.get(field_name)
|
credentials_meta = input_type(**input_data[field_name])
|
||||||
if not field_value or (
|
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
|
||||||
isinstance(field_value, dict) and not field_value.get("id")
|
|
||||||
):
|
|
||||||
# No credentials configured — nullify so JSON schema validation
|
|
||||||
# doesn't choke on the empty default `{}`.
|
|
||||||
input_data[field_name] = None
|
|
||||||
continue # Block runs without credentials
|
|
||||||
|
|
||||||
credentials_meta = input_type(**field_value)
|
|
||||||
# Write normalized values back so JSON schema validation also passes
|
|
||||||
# (model_validator may have fixed legacy formats like "ProviderName.MCP")
|
|
||||||
input_data[field_name] = credentials_meta.model_dump(mode="json")
|
|
||||||
try:
|
|
||||||
credentials, lock = await creds_manager.acquire(
|
|
||||||
user_id, credentials_meta.id
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
# Credential was deleted or doesn't exist.
|
|
||||||
# If the field has a default, run without credentials.
|
|
||||||
if input_model.model_fields[field_name].default is not None:
|
|
||||||
log_metadata.warning(
|
|
||||||
f"Credentials #{credentials_meta.id} not found, "
|
|
||||||
"running without (field has default)"
|
|
||||||
)
|
|
||||||
input_data[field_name] = input_model.model_fields[field_name].default
|
|
||||||
continue
|
|
||||||
raise
|
|
||||||
creds_locks.append(lock)
|
creds_locks.append(lock)
|
||||||
extra_exec_kwargs[field_name] = credentials
|
extra_exec_kwargs[field_name] = credentials
|
||||||
|
|
||||||
|
|||||||
@@ -265,13 +265,7 @@ async def _validate_node_input_credentials(
|
|||||||
# Track if any credential field is missing for this node
|
# Track if any credential field is missing for this node
|
||||||
has_missing_credentials = False
|
has_missing_credentials = False
|
||||||
|
|
||||||
# A credential field is optional if the node metadata says so, or if
|
|
||||||
# the block schema declares a default for the field.
|
|
||||||
required_fields = block.input_schema.get_required_fields()
|
|
||||||
is_creds_optional = node.credentials_optional
|
|
||||||
|
|
||||||
for field_name, credentials_meta_type in credentials_fields.items():
|
for field_name, credentials_meta_type in credentials_fields.items():
|
||||||
field_is_optional = is_creds_optional or field_name not in required_fields
|
|
||||||
try:
|
try:
|
||||||
# Check nodes_input_masks first, then input_default
|
# Check nodes_input_masks first, then input_default
|
||||||
field_value = None
|
field_value = None
|
||||||
@@ -284,7 +278,7 @@ async def _validate_node_input_credentials(
|
|||||||
elif field_name in node.input_default:
|
elif field_name in node.input_default:
|
||||||
# For optional credentials, don't use input_default - treat as missing
|
# For optional credentials, don't use input_default - treat as missing
|
||||||
# This prevents stale credential IDs from failing validation
|
# This prevents stale credential IDs from failing validation
|
||||||
if field_is_optional:
|
if node.credentials_optional:
|
||||||
field_value = None
|
field_value = None
|
||||||
else:
|
else:
|
||||||
field_value = node.input_default[field_name]
|
field_value = node.input_default[field_name]
|
||||||
@@ -294,8 +288,8 @@ async def _validate_node_input_credentials(
|
|||||||
isinstance(field_value, dict) and not field_value.get("id")
|
isinstance(field_value, dict) and not field_value.get("id")
|
||||||
):
|
):
|
||||||
has_missing_credentials = True
|
has_missing_credentials = True
|
||||||
# If credential field is optional, skip instead of error
|
# If node has credentials_optional flag, mark for skipping instead of error
|
||||||
if field_is_optional:
|
if node.credentials_optional:
|
||||||
continue # Don't add error, will be marked for skip after loop
|
continue # Don't add error, will be marked for skip after loop
|
||||||
else:
|
else:
|
||||||
credential_errors[node.id][
|
credential_errors[node.id][
|
||||||
@@ -345,16 +339,16 @@ async def _validate_node_input_credentials(
|
|||||||
] = "Invalid credentials: type/provider mismatch"
|
] = "Invalid credentials: type/provider mismatch"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If node has optional credentials and any are missing, allow running without.
|
# If node has optional credentials and any are missing, mark for skipping
|
||||||
# The executor will pass credentials=None to the block's run().
|
# But only if there are no other errors for this node
|
||||||
if (
|
if (
|
||||||
has_missing_credentials
|
has_missing_credentials
|
||||||
and is_creds_optional
|
and node.credentials_optional
|
||||||
and node.id not in credential_errors
|
and node.id not in credential_errors
|
||||||
):
|
):
|
||||||
|
nodes_to_skip.add(node.id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Node #{node.id}: optional credentials not configured, "
|
f"Node #{node.id} will be skipped: optional credentials not configured"
|
||||||
"running without"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return credential_errors, nodes_to_skip
|
return credential_errors, nodes_to_skip
|
||||||
@@ -379,7 +373,7 @@ def make_node_credentials_input_map(
|
|||||||
# Get aggregated credentials fields for the graph
|
# Get aggregated credentials fields for the graph
|
||||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||||
|
|
||||||
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
|
||||||
# Best-effort map: skip missing items
|
# Best-effort map: skip missing items
|
||||||
if graph_input_name not in graph_credentials_input:
|
if graph_input_name not in graph_credentials_input:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -495,7 +495,6 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
|||||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||||
"credentials": mock_credentials_field_type
|
"credentials": mock_credentials_field_type
|
||||||
}
|
}
|
||||||
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
|
|
||||||
mock_node.block = mock_block
|
mock_node.block = mock_block
|
||||||
|
|
||||||
# Create mock graph
|
# Create mock graph
|
||||||
@@ -509,8 +508,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
|||||||
nodes_input_masks=None,
|
nodes_input_masks=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Node should NOT be in nodes_to_skip (runs without credentials) and not in errors
|
# Node should be in nodes_to_skip, not in errors
|
||||||
assert mock_node.id not in nodes_to_skip
|
assert mock_node.id in nodes_to_skip
|
||||||
assert mock_node.id not in errors
|
assert mock_node.id not in errors
|
||||||
|
|
||||||
|
|
||||||
@@ -536,7 +535,6 @@ async def test_validate_node_input_credentials_required_missing_creds_error(
|
|||||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||||
"credentials": mock_credentials_field_type
|
"credentials": mock_credentials_field_type
|
||||||
}
|
}
|
||||||
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
|
|
||||||
mock_node.block = mock_block
|
mock_node.block = mock_block
|
||||||
|
|
||||||
# Create mock graph
|
# Create mock graph
|
||||||
|
|||||||
@@ -22,27 +22,6 @@ from backend.util.settings import Settings
|
|||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
|
|
||||||
def _provider_matches(stored: str, expected: str) -> bool:
|
|
||||||
"""Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug.
|
|
||||||
|
|
||||||
On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"``
|
|
||||||
instead of ``"mcp"``. OAuth states persisted with the buggy format need
|
|
||||||
to match when ``expected`` is the canonical value (e.g. ``"mcp"``).
|
|
||||||
"""
|
|
||||||
if stored == expected:
|
|
||||||
return True
|
|
||||||
if stored.startswith("ProviderName."):
|
|
||||||
member = stored.removeprefix("ProviderName.")
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
try:
|
|
||||||
return ProviderName[member].value == expected
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
|
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
|
||||||
ollama_credentials = APIKeyCredentials(
|
ollama_credentials = APIKeyCredentials(
|
||||||
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
|
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
|
||||||
@@ -245,14 +224,6 @@ openweathermap_credentials = APIKeyCredentials(
|
|||||||
expires_at=None,
|
expires_at=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
elevenlabs_credentials = APIKeyCredentials(
|
|
||||||
id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c",
|
|
||||||
provider="elevenlabs",
|
|
||||||
api_key=SecretStr(settings.secrets.elevenlabs_api_key),
|
|
||||||
title="Use Credits for ElevenLabs",
|
|
||||||
expires_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_CREDENTIALS = [
|
DEFAULT_CREDENTIALS = [
|
||||||
ollama_credentials,
|
ollama_credentials,
|
||||||
revid_credentials,
|
revid_credentials,
|
||||||
@@ -281,7 +252,6 @@ DEFAULT_CREDENTIALS = [
|
|||||||
v0_credentials,
|
v0_credentials,
|
||||||
webshare_proxy_credentials,
|
webshare_proxy_credentials,
|
||||||
openweathermap_credentials,
|
openweathermap_credentials,
|
||||||
elevenlabs_credentials,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
||||||
@@ -396,8 +366,6 @@ class IntegrationCredentialsStore:
|
|||||||
all_credentials.append(webshare_proxy_credentials)
|
all_credentials.append(webshare_proxy_credentials)
|
||||||
if settings.secrets.openweathermap_api_key:
|
if settings.secrets.openweathermap_api_key:
|
||||||
all_credentials.append(openweathermap_credentials)
|
all_credentials.append(openweathermap_credentials)
|
||||||
if settings.secrets.elevenlabs_api_key:
|
|
||||||
all_credentials.append(elevenlabs_credentials)
|
|
||||||
return all_credentials
|
return all_credentials
|
||||||
|
|
||||||
async def get_creds_by_id(
|
async def get_creds_by_id(
|
||||||
@@ -410,7 +378,7 @@ class IntegrationCredentialsStore:
|
|||||||
self, user_id: str, provider: str
|
self, user_id: str, provider: str
|
||||||
) -> list[Credentials]:
|
) -> list[Credentials]:
|
||||||
credentials = await self.get_all_creds(user_id)
|
credentials = await self.get_all_creds(user_id)
|
||||||
return [c for c in credentials if _provider_matches(c.provider, provider)]
|
return [c for c in credentials if c.provider == provider]
|
||||||
|
|
||||||
async def get_authorized_providers(self, user_id: str) -> list[str]:
|
async def get_authorized_providers(self, user_id: str) -> list[str]:
|
||||||
credentials = await self.get_all_creds(user_id)
|
credentials = await self.get_all_creds(user_id)
|
||||||
@@ -506,6 +474,17 @@ class IntegrationCredentialsStore:
|
|||||||
async with self.edit_user_integrations(user_id) as user_integrations:
|
async with self.edit_user_integrations(user_id) as user_integrations:
|
||||||
user_integrations.oauth_states.append(state)
|
user_integrations.oauth_states.append(state)
|
||||||
|
|
||||||
|
async with await self.locked_user_integrations(user_id):
|
||||||
|
|
||||||
|
user_integrations = await self._get_user_integrations(user_id)
|
||||||
|
oauth_states = user_integrations.oauth_states
|
||||||
|
oauth_states.append(state)
|
||||||
|
user_integrations.oauth_states = oauth_states
|
||||||
|
|
||||||
|
await self.db_manager.update_user_integrations(
|
||||||
|
user_id=user_id, data=user_integrations
|
||||||
|
)
|
||||||
|
|
||||||
return token, code_challenge
|
return token, code_challenge
|
||||||
|
|
||||||
def _generate_code_challenge(self) -> tuple[str, str]:
|
def _generate_code_challenge(self) -> tuple[str, str]:
|
||||||
@@ -531,7 +510,7 @@ class IntegrationCredentialsStore:
|
|||||||
state
|
state
|
||||||
for state in oauth_states
|
for state in oauth_states
|
||||||
if secrets.compare_digest(state.token, token)
|
if secrets.compare_digest(state.token, token)
|
||||||
and _provider_matches(state.provider, provider)
|
and state.provider == provider
|
||||||
and state.expires_at > now.timestamp()
|
and state.expires_at > now.timestamp()
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -137,10 +137,7 @@ class IntegrationCredentialsManager:
|
|||||||
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
|
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
|
||||||
) -> OAuth2Credentials:
|
) -> OAuth2Credentials:
|
||||||
async with self._locked(user_id, credentials.id, "refresh"):
|
async with self._locked(user_id, credentials.id, "refresh"):
|
||||||
if credentials.provider == ProviderName.MCP.value:
|
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
||||||
oauth_handler = _create_mcp_oauth_handler(credentials)
|
|
||||||
else:
|
|
||||||
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
|
||||||
if oauth_handler.needs_refresh(credentials):
|
if oauth_handler.needs_refresh(credentials):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Refreshing '{credentials.provider}' "
|
f"Refreshing '{credentials.provider}' "
|
||||||
@@ -239,25 +236,3 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl
|
|||||||
client_secret=client_secret,
|
client_secret=client_secret,
|
||||||
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _create_mcp_oauth_handler(
|
|
||||||
credentials: OAuth2Credentials,
|
|
||||||
) -> "BaseOAuthHandler":
|
|
||||||
"""Create an MCPOAuthHandler from credential metadata for token refresh.
|
|
||||||
|
|
||||||
MCP OAuth handlers have dynamic endpoints discovered per-server, so they
|
|
||||||
can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler
|
|
||||||
is reconstructed from metadata stored on the credential during initial auth.
|
|
||||||
"""
|
|
||||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|
||||||
|
|
||||||
meta = credentials.metadata or {}
|
|
||||||
return MCPOAuthHandler(
|
|
||||||
client_id=meta.get("mcp_client_id", ""),
|
|
||||||
client_secret=meta.get("mcp_client_secret", ""),
|
|
||||||
redirect_uri="", # Not needed for token refresh
|
|
||||||
authorize_url="", # Not needed for token refresh
|
|
||||||
token_url=meta.get("mcp_token_url", ""),
|
|
||||||
resource_url=meta.get("mcp_resource_url"),
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ class ProviderName(str, Enum):
|
|||||||
DISCORD = "discord"
|
DISCORD = "discord"
|
||||||
D_ID = "d_id"
|
D_ID = "d_id"
|
||||||
E2B = "e2b"
|
E2B = "e2b"
|
||||||
ELEVENLABS = "elevenlabs"
|
|
||||||
FAL = "fal"
|
FAL = "fal"
|
||||||
GITHUB = "github"
|
GITHUB = "github"
|
||||||
GOOGLE = "google"
|
GOOGLE = "google"
|
||||||
@@ -30,7 +29,6 @@ class ProviderName(str, Enum):
|
|||||||
IDEOGRAM = "ideogram"
|
IDEOGRAM = "ideogram"
|
||||||
JINA = "jina"
|
JINA = "jina"
|
||||||
LLAMA_API = "llama_api"
|
LLAMA_API = "llama_api"
|
||||||
MCP = "mcp"
|
|
||||||
MEDIUM = "medium"
|
MEDIUM = "medium"
|
||||||
MEM0 = "mem0"
|
MEM0 = "mem0"
|
||||||
NOTION = "notion"
|
NOTION = "notion"
|
||||||
|
|||||||
@@ -50,21 +50,6 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
|
|||||||
if (
|
if (
|
||||||
creds_meta := new_node.input_default.get(creds_field_name)
|
creds_meta := new_node.input_default.get(creds_field_name)
|
||||||
) and not await get_credentials(creds_meta["id"]):
|
) and not await get_credentials(creds_meta["id"]):
|
||||||
# If the credential field is optional (has a default in the
|
|
||||||
# schema, or node metadata marks it optional), clear the stale
|
|
||||||
# reference instead of blocking the save.
|
|
||||||
creds_field_optional = (
|
|
||||||
new_node.credentials_optional
|
|
||||||
or creds_field_name not in block_input_schema.get_required_fields()
|
|
||||||
)
|
|
||||||
if creds_field_optional:
|
|
||||||
new_node.input_default[creds_field_name] = {}
|
|
||||||
logger.warning(
|
|
||||||
f"Node #{new_node.id}: cleared stale optional "
|
|
||||||
f"credentials #{creds_meta['id']} for "
|
|
||||||
f"'{creds_field_name}'"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||||
f"non-existent credentials #{creds_meta['id']}"
|
f"non-existent credentials #{creds_meta['id']}"
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ from pathlib import Path
|
|||||||
from typing import TYPE_CHECKING, Literal
|
from typing import TYPE_CHECKING, Literal
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
@@ -19,35 +17,6 @@ from backend.util.virus_scanner import scan_content_safe
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceUri(BaseModel):
|
|
||||||
"""Parsed workspace:// URI."""
|
|
||||||
|
|
||||||
file_ref: str # File ID or path (e.g. "abc123" or "/path/to/file.txt")
|
|
||||||
mime_type: str | None = None # MIME type from fragment (e.g. "video/mp4")
|
|
||||||
is_path: bool = False # True if file_ref is a path (starts with "/")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_workspace_uri(uri: str) -> WorkspaceUri:
|
|
||||||
"""Parse a workspace:// URI into its components.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
"workspace://abc123" → WorkspaceUri(file_ref="abc123", mime_type=None, is_path=False)
|
|
||||||
"workspace://abc123#video/mp4" → WorkspaceUri(file_ref="abc123", mime_type="video/mp4", is_path=False)
|
|
||||||
"workspace:///path/to/file.txt" → WorkspaceUri(file_ref="/path/to/file.txt", mime_type=None, is_path=True)
|
|
||||||
"""
|
|
||||||
raw = uri.removeprefix("workspace://")
|
|
||||||
mime_type: str | None = None
|
|
||||||
if "#" in raw:
|
|
||||||
raw, fragment = raw.split("#", 1)
|
|
||||||
mime_type = fragment or None
|
|
||||||
return WorkspaceUri(
|
|
||||||
file_ref=raw,
|
|
||||||
mime_type=mime_type,
|
|
||||||
is_path=raw.startswith("/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Return format options for store_media_file
|
# Return format options for store_media_file
|
||||||
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||||
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
||||||
@@ -214,20 +183,22 @@ async def store_media_file(
|
|||||||
"This file type is only available in CoPilot sessions."
|
"This file type is only available in CoPilot sessions."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse workspace reference (strips #mimeType fragment from file ID)
|
# Parse workspace reference
|
||||||
ws = parse_workspace_uri(file)
|
# workspace://abc123 - by file ID
|
||||||
|
# workspace:///path/to/file.txt - by virtual path
|
||||||
|
file_ref = file[12:] # Remove "workspace://"
|
||||||
|
|
||||||
if ws.is_path:
|
if file_ref.startswith("/"):
|
||||||
# Path reference: workspace:///path/to/file.txt
|
# Path reference
|
||||||
workspace_content = await workspace_manager.read_file(ws.file_ref)
|
workspace_content = await workspace_manager.read_file(file_ref)
|
||||||
file_info = await workspace_manager.get_file_info_by_path(ws.file_ref)
|
file_info = await workspace_manager.get_file_info_by_path(file_ref)
|
||||||
filename = sanitize_filename(
|
filename = sanitize_filename(
|
||||||
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# ID reference: workspace://abc123 or workspace://abc123#video/mp4
|
# ID reference
|
||||||
workspace_content = await workspace_manager.read_file_by_id(ws.file_ref)
|
workspace_content = await workspace_manager.read_file_by_id(file_ref)
|
||||||
file_info = await workspace_manager.get_file_info(ws.file_ref)
|
file_info = await workspace_manager.get_file_info(file_ref)
|
||||||
filename = sanitize_filename(
|
filename = sanitize_filename(
|
||||||
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||||
)
|
)
|
||||||
@@ -342,14 +313,6 @@ async def store_media_file(
|
|||||||
if not target_path.is_file():
|
if not target_path.is_file():
|
||||||
raise ValueError(f"Local file does not exist: {target_path}")
|
raise ValueError(f"Local file does not exist: {target_path}")
|
||||||
|
|
||||||
# Virus scan the local file before any further processing
|
|
||||||
local_content = target_path.read_bytes()
|
|
||||||
if len(local_content) > MAX_FILE_SIZE_BYTES:
|
|
||||||
raise ValueError(
|
|
||||||
f"File too large: {len(local_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
|
||||||
)
|
|
||||||
await scan_content_safe(local_content, filename=sanitized_file)
|
|
||||||
|
|
||||||
# Return based on requested format
|
# Return based on requested format
|
||||||
if return_format == "for_local_processing":
|
if return_format == "for_local_processing":
|
||||||
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
||||||
@@ -371,21 +334,7 @@ async def store_media_file(
|
|||||||
|
|
||||||
# Don't re-save if input was already from workspace
|
# Don't re-save if input was already from workspace
|
||||||
if is_from_workspace:
|
if is_from_workspace:
|
||||||
# Return original workspace reference, ensuring MIME type fragment
|
# Return original workspace reference
|
||||||
ws = parse_workspace_uri(file)
|
|
||||||
if not ws.mime_type:
|
|
||||||
# Add MIME type fragment if missing (older refs without it)
|
|
||||||
try:
|
|
||||||
if ws.is_path:
|
|
||||||
info = await workspace_manager.get_file_info_by_path(
|
|
||||||
ws.file_ref
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
info = await workspace_manager.get_file_info(ws.file_ref)
|
|
||||||
if info:
|
|
||||||
return MediaFileType(f"{file}#{info.mimeType}")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return MediaFileType(file)
|
return MediaFileType(file)
|
||||||
|
|
||||||
# Save new content to workspace
|
# Save new content to workspace
|
||||||
@@ -397,7 +346,7 @@ async def store_media_file(
|
|||||||
filename=filename,
|
filename=filename,
|
||||||
overwrite=True,
|
overwrite=True,
|
||||||
)
|
)
|
||||||
return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}")
|
return MediaFileType(f"workspace://{file_record.id}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid return_format: {return_format}")
|
raise ValueError(f"Invalid return_format: {return_format}")
|
||||||
|
|||||||
@@ -247,100 +247,3 @@ class TestFileCloudIntegration:
|
|||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
return_format="for_local_processing",
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_media_file_local_path_scanned(self):
|
|
||||||
"""Test that local file paths are scanned for viruses."""
|
|
||||||
graph_exec_id = "test-exec-123"
|
|
||||||
local_file = "test_video.mp4"
|
|
||||||
file_content = b"fake video content"
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.util.file.get_cloud_storage_handler"
|
|
||||||
) as mock_handler_getter, patch(
|
|
||||||
"backend.util.file.scan_content_safe"
|
|
||||||
) as mock_scan, patch(
|
|
||||||
"backend.util.file.Path"
|
|
||||||
) as mock_path_class:
|
|
||||||
|
|
||||||
# Mock cloud storage handler - not a cloud path
|
|
||||||
mock_handler = MagicMock()
|
|
||||||
mock_handler.is_cloud_path.return_value = False
|
|
||||||
mock_handler_getter.return_value = mock_handler
|
|
||||||
|
|
||||||
# Mock virus scanner
|
|
||||||
mock_scan.return_value = None
|
|
||||||
|
|
||||||
# Mock file system operations
|
|
||||||
mock_base_path = MagicMock()
|
|
||||||
mock_target_path = MagicMock()
|
|
||||||
mock_resolved_path = MagicMock()
|
|
||||||
|
|
||||||
mock_path_class.return_value = mock_base_path
|
|
||||||
mock_base_path.mkdir = MagicMock()
|
|
||||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
|
||||||
mock_target_path.resolve.return_value = mock_resolved_path
|
|
||||||
mock_resolved_path.is_relative_to.return_value = True
|
|
||||||
mock_resolved_path.is_file.return_value = True
|
|
||||||
mock_resolved_path.read_bytes.return_value = file_content
|
|
||||||
mock_resolved_path.relative_to.return_value = Path(local_file)
|
|
||||||
mock_resolved_path.name = local_file
|
|
||||||
|
|
||||||
result = await store_media_file(
|
|
||||||
file=MediaFileType(local_file),
|
|
||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify virus scan was called for local file
|
|
||||||
mock_scan.assert_called_once_with(file_content, filename=local_file)
|
|
||||||
|
|
||||||
# Result should be the relative path
|
|
||||||
assert str(result) == local_file
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_media_file_local_path_virus_detected(self):
|
|
||||||
"""Test that infected local files raise VirusDetectedError."""
|
|
||||||
from backend.api.features.store.exceptions import VirusDetectedError
|
|
||||||
|
|
||||||
graph_exec_id = "test-exec-123"
|
|
||||||
local_file = "infected.exe"
|
|
||||||
file_content = b"malicious content"
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.util.file.get_cloud_storage_handler"
|
|
||||||
) as mock_handler_getter, patch(
|
|
||||||
"backend.util.file.scan_content_safe"
|
|
||||||
) as mock_scan, patch(
|
|
||||||
"backend.util.file.Path"
|
|
||||||
) as mock_path_class:
|
|
||||||
|
|
||||||
# Mock cloud storage handler - not a cloud path
|
|
||||||
mock_handler = MagicMock()
|
|
||||||
mock_handler.is_cloud_path.return_value = False
|
|
||||||
mock_handler_getter.return_value = mock_handler
|
|
||||||
|
|
||||||
# Mock virus scanner to detect virus
|
|
||||||
mock_scan.side_effect = VirusDetectedError(
|
|
||||||
"EICAR-Test-File", "File rejected due to virus detection"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock file system operations
|
|
||||||
mock_base_path = MagicMock()
|
|
||||||
mock_target_path = MagicMock()
|
|
||||||
mock_resolved_path = MagicMock()
|
|
||||||
|
|
||||||
mock_path_class.return_value = mock_base_path
|
|
||||||
mock_base_path.mkdir = MagicMock()
|
|
||||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
|
||||||
mock_target_path.resolve.return_value = mock_resolved_path
|
|
||||||
mock_resolved_path.is_relative_to.return_value = True
|
|
||||||
mock_resolved_path.is_file.return_value = True
|
|
||||||
mock_resolved_path.read_bytes.return_value = file_content
|
|
||||||
|
|
||||||
with pytest.raises(VirusDetectedError):
|
|
||||||
await store_media_file(
|
|
||||||
file=MediaFileType(local_file),
|
|
||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from pydantic import SecretStr
|
|||||||
from sentry_sdk.integrations import DidNotEnable
|
from sentry_sdk.integrations import DidNotEnable
|
||||||
from sentry_sdk.integrations.anthropic import AnthropicIntegration
|
from sentry_sdk.integrations.anthropic import AnthropicIntegration
|
||||||
from sentry_sdk.integrations.asyncio import AsyncioIntegration
|
from sentry_sdk.integrations.asyncio import AsyncioIntegration
|
||||||
|
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||||
|
from sentry_sdk.integrations.httpx import HttpxIntegration
|
||||||
from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration
|
from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration
|
||||||
from sentry_sdk.integrations.logging import LoggingIntegration
|
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||||
|
|
||||||
@@ -37,6 +39,8 @@ def sentry_init():
|
|||||||
_experiments={"enable_logs": True},
|
_experiments={"enable_logs": True},
|
||||||
integrations=[
|
integrations=[
|
||||||
AsyncioIntegration(),
|
AsyncioIntegration(),
|
||||||
|
FastApiIntegration(), # Traces FastAPI requests with detailed spans
|
||||||
|
HttpxIntegration(), # Traces outgoing HTTP calls (OpenAI, external APIs)
|
||||||
LoggingIntegration(sentry_logs_level=logging.INFO),
|
LoggingIntegration(sentry_logs_level=logging.INFO),
|
||||||
AnthropicIntegration(
|
AnthropicIntegration(
|
||||||
include_prompts=False,
|
include_prompts=False,
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver):
|
|||||||
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
|
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
|
||||||
self.ssl_hostname = ssl_hostname
|
self.ssl_hostname = ssl_hostname
|
||||||
self.ip_addresses = ip_addresses
|
self.ip_addresses = ip_addresses
|
||||||
self._default = aiohttp.ThreadedResolver()
|
self._default = aiohttp.AsyncResolver()
|
||||||
|
|
||||||
async def resolve(self, host, port=0, family=socket.AF_INET):
|
async def resolve(self, host, port=0, family=socket.AF_INET):
|
||||||
if host == self.ssl_hostname:
|
if host == self.ssl_hostname:
|
||||||
@@ -467,7 +467,7 @@ class Requests:
|
|||||||
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
|
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
|
||||||
ssl_context = ssl.create_default_context()
|
ssl_context = ssl.create_default_context()
|
||||||
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
|
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
|
||||||
session_kwargs: dict = {}
|
session_kwargs = {}
|
||||||
if connector:
|
if connector:
|
||||||
session_kwargs["connector"] = connector
|
session_kwargs["connector"] = connector
|
||||||
|
|
||||||
|
|||||||
@@ -656,7 +656,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|||||||
e2b_api_key: str = Field(default="", description="E2B API key")
|
e2b_api_key: str = Field(default="", description="E2B API key")
|
||||||
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
||||||
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
||||||
elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
|
|
||||||
|
|
||||||
linear_client_id: str = Field(default="", description="Linear client ID")
|
linear_client_id: str = Field(default="", description="Linear client ID")
|
||||||
linear_client_secret: str = Field(default="", description="Linear client secret")
|
linear_client_secret: str = Field(default="", description="Linear client secret")
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from backend.data.workspace import (
|
|||||||
soft_delete_workspace_file,
|
soft_delete_workspace_file,
|
||||||
)
|
)
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
from backend.util.virus_scanner import scan_content_safe
|
|
||||||
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -188,9 +187,6 @@ class WorkspaceManager:
|
|||||||
f"{Config().max_file_size_mb}MB limit"
|
f"{Config().max_file_size_mb}MB limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Virus scan content before persisting (defense in depth)
|
|
||||||
await scan_content_safe(content, filename=filename)
|
|
||||||
|
|
||||||
# Determine path with session scoping
|
# Determine path with session scoping
|
||||||
if path is None:
|
if path is None:
|
||||||
path = f"/{filename}"
|
path = f"/{filename}"
|
||||||
|
|||||||
7054
autogpt_platform/backend/poetry.lock
generated
7054
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user