mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-13 08:14:58 -05:00
Compare commits
1 Commits
fix/claude
...
refactor/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f7a7067ec |
2
.github/workflows/classic-frontend-ci.yml
vendored
2
.github/workflows/classic-frontend-ci.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||||
if: github.event_name == 'push'
|
if: github.event_name == 'push'
|
||||||
uses: peter-evans/create-pull-request@v8
|
uses: peter-evans/create-pull-request@v7
|
||||||
with:
|
with:
|
||||||
add-paths: classic/frontend/build/web
|
add-paths: classic/frontend/build/web
|
||||||
base: ${{ github.ref_name }}
|
base: ${{ github.ref_name }}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.event.workflow_run.head_branch }}
|
ref: ${{ github.event.workflow_run.head_branch }}
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
@@ -42,7 +42,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Get CI failure details
|
- name: Get CI failure details
|
||||||
id: failure_details
|
id: failure_details
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const run = await github.rest.actions.getWorkflowRun({
|
const run = await github.rest.actions.getWorkflowRun({
|
||||||
|
|||||||
11
.github/workflows/claude-dependabot.yml
vendored
11
.github/workflows/claude-dependabot.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
actions: read # Required for CI access
|
actions: read # Required for CI access
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
@@ -41,7 +41,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -78,7 +78,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -124,7 +124,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
@@ -309,7 +309,6 @@ jobs:
|
|||||||
uses: anthropics/claude-code-action@v1
|
uses: anthropics/claude-code-action@v1
|
||||||
with:
|
with:
|
||||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||||
allowed_bots: "dependabot[bot]"
|
|
||||||
claude_args: |
|
claude_args: |
|
||||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||||
prompt: |
|
prompt: |
|
||||||
|
|||||||
10
.github/workflows/claude.yml
vendored
10
.github/workflows/claude.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
|||||||
actions: read # Required for CI access
|
actions: read # Required for CI access
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -94,7 +94,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -140,7 +140,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
|
|||||||
2
.github/workflows/codeql.yml
vendored
2
.github/workflows/codeql.yml
vendored
@@ -58,7 +58,7 @@ jobs:
|
|||||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
# Initializes the CodeQL tools for scanning.
|
# Initializes the CodeQL tools for scanning.
|
||||||
- name: Initialize CodeQL
|
- name: Initialize CodeQL
|
||||||
|
|||||||
10
.github/workflows/copilot-setup-steps.yml
vendored
10
.github/workflows/copilot-setup-steps.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
# If you do not check out your code, Copilot will do this for you.
|
# If you do not check out your code, Copilot will do this for you.
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
submodules: true
|
submodules: true
|
||||||
@@ -39,7 +39,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -76,7 +76,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -132,7 +132,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
|
|||||||
4
.github/workflows/docs-block-sync.yml
vendored
4
.github/workflows/docs-block-sync.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
4
.github/workflows/docs-claude-review.yml
vendored
4
.github/workflows/docs-claude-review.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
4
.github/workflows/docs-enhance.yml
vendored
4
.github/workflows/docs-enhance.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
|
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger deploy workflow
|
- name: Trigger deploy workflow
|
||||||
uses: peter-evans/repository-dispatch@v4
|
uses: peter-evans/repository-dispatch@v3
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.ref_name || 'master' }}
|
ref: ${{ github.ref_name || 'master' }}
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger deploy workflow
|
- name: Trigger deploy workflow
|
||||||
uses: peter-evans/repository-dispatch@v4
|
uses: peter-evans/repository-dispatch@v3
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||||
|
|||||||
4
.github/workflows/platform-backend-ci.yml
vendored
4
.github/workflows/platform-backend-ci.yml
vendored
@@ -68,7 +68,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
submodules: true
|
submodules: true
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ jobs:
|
|||||||
- name: Check comment permissions and deployment status
|
- name: Check comment permissions and deployment status
|
||||||
id: check_status
|
id: check_status
|
||||||
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const commentBody = context.payload.comment.body.trim();
|
const commentBody = context.payload.comment.body.trim();
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post permission denied comment
|
- name: Post permission denied comment
|
||||||
if: steps.check_status.outputs.permission_denied == 'true'
|
if: steps.check_status.outputs.permission_denied == 'true'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -68,7 +68,7 @@ jobs:
|
|||||||
- name: Get PR details for deployment
|
- name: Get PR details for deployment
|
||||||
id: pr_details
|
id: pr_details
|
||||||
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const pr = await github.rest.pulls.get({
|
const pr = await github.rest.pulls.get({
|
||||||
@@ -82,7 +82,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Dispatch Deploy Event
|
- name: Dispatch Deploy Event
|
||||||
if: steps.check_status.outputs.should_deploy == 'true'
|
if: steps.check_status.outputs.should_deploy == 'true'
|
||||||
uses: peter-evans/repository-dispatch@v4
|
uses: peter-evans/repository-dispatch@v3
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||||
@@ -98,7 +98,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post deploy success comment
|
- name: Post deploy success comment
|
||||||
if: steps.check_status.outputs.should_deploy == 'true'
|
if: steps.check_status.outputs.should_deploy == 'true'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -110,7 +110,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Dispatch Undeploy Event (from comment)
|
- name: Dispatch Undeploy Event (from comment)
|
||||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||||
uses: peter-evans/repository-dispatch@v4
|
uses: peter-evans/repository-dispatch@v3
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||||
@@ -126,7 +126,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post undeploy success comment
|
- name: Post undeploy success comment
|
||||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -139,7 +139,7 @@ jobs:
|
|||||||
- name: Check deployment status on PR close
|
- name: Check deployment status on PR close
|
||||||
id: check_pr_close
|
id: check_pr_close
|
||||||
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const comments = await github.rest.issues.listComments({
|
const comments = await github.rest.issues.listComments({
|
||||||
@@ -168,7 +168,7 @@ jobs:
|
|||||||
github.event_name == 'pull_request' &&
|
github.event_name == 'pull_request' &&
|
||||||
github.event.action == 'closed' &&
|
github.event.action == 'closed' &&
|
||||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||||
uses: peter-evans/repository-dispatch@v4
|
uses: peter-evans/repository-dispatch@v3
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||||
@@ -187,7 +187,7 @@ jobs:
|
|||||||
github.event_name == 'pull_request' &&
|
github.event_name == 'pull_request' &&
|
||||||
github.event.action == 'closed' &&
|
github.event.action == 'closed' &&
|
||||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
|
|||||||
48
.github/workflows/platform-frontend-ci.yml
vendored
48
.github/workflows/platform-frontend-ci.yml
vendored
@@ -27,22 +27,13 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
outputs:
|
||||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||||
components-changed: ${{ steps.filter.outputs.components }}
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Check for component changes
|
|
||||||
uses: dorny/paths-filter@v3
|
|
||||||
id: filter
|
|
||||||
with:
|
|
||||||
filters: |
|
|
||||||
components:
|
|
||||||
- 'autogpt_platform/frontend/src/components/**'
|
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -54,7 +45,7 @@ jobs:
|
|||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
key: ${{ steps.cache-key.outputs.key }}
|
||||||
@@ -71,10 +62,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -82,7 +73,7 @@ jobs:
|
|||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
@@ -99,20 +90,17 @@ jobs:
|
|||||||
chromatic:
|
chromatic:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: setup
|
needs: setup
|
||||||
# Disabled: to re-enable, remove 'false &&' from the condition below
|
# Only run on dev branch pushes or PRs targeting dev
|
||||||
if: >-
|
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
|
||||||
false
|
|
||||||
&& (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
|
|
||||||
&& needs.setup.outputs.components-changed == 'true'
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -120,7 +108,7 @@ jobs:
|
|||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
@@ -148,12 +136,12 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -176,7 +164,7 @@ jobs:
|
|||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
- name: Cache Docker layers
|
- name: Cache Docker layers
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: /tmp/.buildx-cache
|
path: /tmp/.buildx-cache
|
||||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -231,7 +219,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
@@ -277,12 +265,12 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -290,7 +278,7 @@ jobs:
|
|||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
|||||||
16
.github/workflows/platform-fullstack-ci.yml
vendored
16
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -29,10 +29,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ jobs:
|
|||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
key: ${{ steps.cache-key.outputs.key }}
|
||||||
@@ -56,19 +56,19 @@ jobs:
|
|||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
types:
|
types:
|
||||||
runs-on: big-boi
|
runs-on: ubuntu-latest
|
||||||
needs: setup
|
needs: setup
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -85,10 +85,10 @@ jobs:
|
|||||||
|
|
||||||
- name: Run docker compose
|
- name: Run docker compose
|
||||||
run: |
|
run: |
|
||||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
|||||||
2
.github/workflows/repo-workflow-checker.yml
vendored
2
.github/workflows/repo-workflow-checker.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
# - name: Wait some time for all actions to start
|
# - name: Wait some time for all actions to start
|
||||||
# run: sleep 30
|
# run: sleep 30
|
||||||
- uses: actions/checkout@v6
|
- uses: actions/checkout@v4
|
||||||
# with:
|
# with:
|
||||||
# fetch-depth: 0
|
# fetch-depth: 0
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
|
|||||||
1854
autogpt_platform/autogpt_libs/poetry.lock
generated
1854
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
|
|||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.10,<4.0"
|
python = ">=3.10,<4.0"
|
||||||
colorama = "^0.4.6"
|
colorama = "^0.4.6"
|
||||||
cryptography = "^46.0"
|
cryptography = "^45.0"
|
||||||
expiringdict = "^1.2.2"
|
expiringdict = "^1.2.2"
|
||||||
fastapi = "^0.128.0"
|
fastapi = "^0.116.1"
|
||||||
google-cloud-logging = "^3.13.0"
|
google-cloud-logging = "^3.12.1"
|
||||||
launchdarkly-server-sdk = "^9.14.1"
|
launchdarkly-server-sdk = "^9.12.0"
|
||||||
pydantic = "^2.12.5"
|
pydantic = "^2.11.7"
|
||||||
pydantic-settings = "^2.12.0"
|
pydantic-settings = "^2.10.1"
|
||||||
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
supabase = "^2.27.2"
|
supabase = "^2.16.0"
|
||||||
uvicorn = "^0.40.0"
|
uvicorn = "^0.35.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pyright = "^1.1.408"
|
pyright = "^1.1.404"
|
||||||
pytest = "^8.4.1"
|
pytest = "^8.4.1"
|
||||||
pytest-asyncio = "^1.3.0"
|
pytest-asyncio = "^1.1.0"
|
||||||
pytest-mock = "^3.15.1"
|
pytest-mock = "^3.14.1"
|
||||||
pytest-cov = "^7.0.0"
|
pytest-cov = "^6.2.1"
|
||||||
ruff = "^0.15.0"
|
ruff = "^0.12.11"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|||||||
@@ -152,7 +152,6 @@ REPLICATE_API_KEY=
|
|||||||
REVID_API_KEY=
|
REVID_API_KEY=
|
||||||
SCREENSHOTONE_API_KEY=
|
SCREENSHOTONE_API_KEY=
|
||||||
UNREAL_SPEECH_API_KEY=
|
UNREAL_SPEECH_API_KEY=
|
||||||
ELEVENLABS_API_KEY=
|
|
||||||
|
|
||||||
# Data & Search Services
|
# Data & Search Services
|
||||||
E2B_API_KEY=
|
E2B_API_KEY=
|
||||||
|
|||||||
3
autogpt_platform/backend/.gitignore
vendored
3
autogpt_platform/backend/.gitignore
vendored
@@ -19,6 +19,3 @@ load-tests/*.json
|
|||||||
load-tests/*.log
|
load-tests/*.log
|
||||||
load-tests/node_modules/*
|
load-tests/node_modules/*
|
||||||
migrations/*/rollback*.sql
|
migrations/*/rollback*.sql
|
||||||
|
|
||||||
# Workspace files
|
|
||||||
workspaces/
|
|
||||||
|
|||||||
@@ -62,12 +62,10 @@ ENV POETRY_HOME=/opt/poetry \
|
|||||||
DEBIAN_FRONTEND=noninteractive
|
DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PATH=/opt/poetry/bin:$PATH
|
ENV PATH=/opt/poetry/bin:$PATH
|
||||||
|
|
||||||
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
# Install Python without upgrading system-managed packages
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
python3.13 \
|
python3.13 \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
ffmpeg \
|
|
||||||
imagemagick \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy only necessary files from builder
|
# Copy only necessary files from builder
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):
|
|||||||
|
|
||||||
# OpenAI API Configuration
|
# OpenAI API Configuration
|
||||||
model: str = Field(
|
model: str = Field(
|
||||||
default="anthropic/claude-opus-4.6", description="Default model to use"
|
default="anthropic/claude-opus-4.5", description="Default model to use"
|
||||||
)
|
)
|
||||||
title_model: str = Field(
|
title_model: str = Field(
|
||||||
default="openai/gpt-4o-mini",
|
default="openai/gpt-4o-mini",
|
||||||
@@ -93,12 +93,6 @@ class ChatConfig(BaseSettings):
|
|||||||
description="Name of the prompt in Langfuse to fetch",
|
description="Name of the prompt in Langfuse to fetch",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extended thinking configuration for Claude models
|
|
||||||
thinking_enabled: bool = Field(
|
|
||||||
default=True,
|
|
||||||
description="Enable adaptive thinking for Claude models via OpenRouter",
|
|
||||||
)
|
|
||||||
|
|
||||||
@field_validator("api_key", mode="before")
|
@field_validator("api_key", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_api_key(cls, v):
|
def get_api_key(cls, v):
|
||||||
|
|||||||
@@ -45,7 +45,10 @@ async def create_chat_session(
|
|||||||
successfulAgentRuns=SafeJson({}),
|
successfulAgentRuns=SafeJson({}),
|
||||||
successfulAgentSchedules=SafeJson({}),
|
successfulAgentSchedules=SafeJson({}),
|
||||||
)
|
)
|
||||||
return await PrismaChatSession.prisma().create(data=data)
|
return await PrismaChatSession.prisma().create(
|
||||||
|
data=data,
|
||||||
|
include={"Messages": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def update_chat_session(
|
async def update_chat_session(
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any, cast
|
from typing import Any
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
@@ -104,26 +104,6 @@ class ChatSession(BaseModel):
|
|||||||
successful_agent_runs: dict[str, int] = {}
|
successful_agent_runs: dict[str, int] = {}
|
||||||
successful_agent_schedules: dict[str, int] = {}
|
successful_agent_schedules: dict[str, int] = {}
|
||||||
|
|
||||||
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
|
|
||||||
"""Attach a tool_call to the current turn's assistant message.
|
|
||||||
|
|
||||||
Searches backwards for the most recent assistant message (stopping at
|
|
||||||
any user message boundary). If found, appends the tool_call to it.
|
|
||||||
Otherwise creates a new assistant message with the tool_call.
|
|
||||||
"""
|
|
||||||
for msg in reversed(self.messages):
|
|
||||||
if msg.role == "user":
|
|
||||||
break
|
|
||||||
if msg.role == "assistant":
|
|
||||||
if not msg.tool_calls:
|
|
||||||
msg.tool_calls = []
|
|
||||||
msg.tool_calls.append(tool_call)
|
|
||||||
return
|
|
||||||
|
|
||||||
self.messages.append(
|
|
||||||
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def new(user_id: str) -> "ChatSession":
|
def new(user_id: str) -> "ChatSession":
|
||||||
return ChatSession(
|
return ChatSession(
|
||||||
@@ -192,47 +172,6 @@ class ChatSession(BaseModel):
|
|||||||
successful_agent_schedules=successful_agent_schedules,
|
successful_agent_schedules=successful_agent_schedules,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _merge_consecutive_assistant_messages(
|
|
||||||
messages: list[ChatCompletionMessageParam],
|
|
||||||
) -> list[ChatCompletionMessageParam]:
|
|
||||||
"""Merge consecutive assistant messages into single messages.
|
|
||||||
|
|
||||||
Long-running tool flows can create split assistant messages: one with
|
|
||||||
text content and another with tool_calls. Anthropic's API requires
|
|
||||||
tool_result blocks to reference a tool_use in the immediately preceding
|
|
||||||
assistant message, so these splits cause 400 errors via OpenRouter.
|
|
||||||
"""
|
|
||||||
if len(messages) < 2:
|
|
||||||
return messages
|
|
||||||
|
|
||||||
result: list[ChatCompletionMessageParam] = [messages[0]]
|
|
||||||
for msg in messages[1:]:
|
|
||||||
prev = result[-1]
|
|
||||||
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
|
|
||||||
result.append(msg)
|
|
||||||
continue
|
|
||||||
|
|
||||||
prev = cast(ChatCompletionAssistantMessageParam, prev)
|
|
||||||
curr = cast(ChatCompletionAssistantMessageParam, msg)
|
|
||||||
|
|
||||||
curr_content = curr.get("content") or ""
|
|
||||||
if curr_content:
|
|
||||||
prev_content = prev.get("content") or ""
|
|
||||||
prev["content"] = (
|
|
||||||
f"{prev_content}\n{curr_content}" if prev_content else curr_content
|
|
||||||
)
|
|
||||||
|
|
||||||
curr_tool_calls = curr.get("tool_calls")
|
|
||||||
if curr_tool_calls:
|
|
||||||
prev_tool_calls = prev.get("tool_calls")
|
|
||||||
prev["tool_calls"] = (
|
|
||||||
list(prev_tool_calls) + list(curr_tool_calls)
|
|
||||||
if prev_tool_calls
|
|
||||||
else list(curr_tool_calls)
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||||
messages = []
|
messages = []
|
||||||
for message in self.messages:
|
for message in self.messages:
|
||||||
@@ -319,7 +258,7 @@ class ChatSession(BaseModel):
|
|||||||
name=message.name or "",
|
name=message.name or "",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return self._merge_consecutive_assistant_messages(messages)
|
return messages
|
||||||
|
|
||||||
|
|
||||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||||
|
|||||||
@@ -1,16 +1,4 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from openai.types.chat import (
|
|
||||||
ChatCompletionAssistantMessageParam,
|
|
||||||
ChatCompletionMessageParam,
|
|
||||||
ChatCompletionToolMessageParam,
|
|
||||||
ChatCompletionUserMessageParam,
|
|
||||||
)
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
|
||||||
ChatCompletionMessageToolCallParam,
|
|
||||||
Function,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .model import (
|
from .model import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
@@ -129,205 +117,3 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
|||||||
loaded.tool_calls is not None
|
loaded.tool_calls is not None
|
||||||
), f"Tool calls missing for {orig.role} message"
|
), f"Tool calls missing for {orig.role} message"
|
||||||
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
# _merge_consecutive_assistant_messages #
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
_tc = ChatCompletionMessageToolCallParam(
|
|
||||||
id="tc1", type="function", function=Function(name="do_stuff", arguments="{}")
|
|
||||||
)
|
|
||||||
_tc2 = ChatCompletionMessageToolCallParam(
|
|
||||||
id="tc2", type="function", function=Function(name="other", arguments="{}")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_noop_when_no_consecutive_assistants():
|
|
||||||
"""Messages without consecutive assistants are returned unchanged."""
|
|
||||||
msgs = [
|
|
||||||
ChatCompletionUserMessageParam(role="user", content="hi"),
|
|
||||||
ChatCompletionAssistantMessageParam(role="assistant", content="hello"),
|
|
||||||
ChatCompletionUserMessageParam(role="user", content="bye"),
|
|
||||||
]
|
|
||||||
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
|
|
||||||
assert len(merged) == 3
|
|
||||||
assert [m["role"] for m in merged] == ["user", "assistant", "user"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_splits_text_and_tool_calls():
|
|
||||||
"""The exact bug scenario: text-only assistant followed by tool_calls-only assistant."""
|
|
||||||
msgs = [
|
|
||||||
ChatCompletionUserMessageParam(role="user", content="build agent"),
|
|
||||||
ChatCompletionAssistantMessageParam(
|
|
||||||
role="assistant", content="Let me build that"
|
|
||||||
),
|
|
||||||
ChatCompletionAssistantMessageParam(
|
|
||||||
role="assistant", content="", tool_calls=[_tc]
|
|
||||||
),
|
|
||||||
ChatCompletionToolMessageParam(role="tool", content="ok", tool_call_id="tc1"),
|
|
||||||
]
|
|
||||||
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
|
|
||||||
|
|
||||||
assert len(merged) == 3
|
|
||||||
assert merged[0]["role"] == "user"
|
|
||||||
assert merged[2]["role"] == "tool"
|
|
||||||
a = cast(ChatCompletionAssistantMessageParam, merged[1])
|
|
||||||
assert a["role"] == "assistant"
|
|
||||||
assert a.get("content") == "Let me build that"
|
|
||||||
assert a.get("tool_calls") == [_tc]
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_combines_tool_calls_from_both():
|
|
||||||
"""Both consecutive assistants have tool_calls — they get merged."""
|
|
||||||
msgs: list[ChatCompletionAssistantMessageParam] = [
|
|
||||||
ChatCompletionAssistantMessageParam(
|
|
||||||
role="assistant", content="text", tool_calls=[_tc]
|
|
||||||
),
|
|
||||||
ChatCompletionAssistantMessageParam(
|
|
||||||
role="assistant", content="", tool_calls=[_tc2]
|
|
||||||
),
|
|
||||||
]
|
|
||||||
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert len(merged) == 1
|
|
||||||
a = cast(ChatCompletionAssistantMessageParam, merged[0])
|
|
||||||
assert a.get("tool_calls") == [_tc, _tc2]
|
|
||||||
assert a.get("content") == "text"
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_three_consecutive_assistants():
|
|
||||||
"""Three consecutive assistants collapse into one."""
|
|
||||||
msgs: list[ChatCompletionAssistantMessageParam] = [
|
|
||||||
ChatCompletionAssistantMessageParam(role="assistant", content="a"),
|
|
||||||
ChatCompletionAssistantMessageParam(role="assistant", content="b"),
|
|
||||||
ChatCompletionAssistantMessageParam(
|
|
||||||
role="assistant", content="", tool_calls=[_tc]
|
|
||||||
),
|
|
||||||
]
|
|
||||||
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert len(merged) == 1
|
|
||||||
a = cast(ChatCompletionAssistantMessageParam, merged[0])
|
|
||||||
assert a.get("content") == "a\nb"
|
|
||||||
assert a.get("tool_calls") == [_tc]
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_empty_and_single_message():
|
|
||||||
"""Edge cases: empty list and single message."""
|
|
||||||
assert ChatSession._merge_consecutive_assistant_messages([]) == []
|
|
||||||
|
|
||||||
single: list[ChatCompletionMessageParam] = [
|
|
||||||
ChatCompletionUserMessageParam(role="user", content="hi")
|
|
||||||
]
|
|
||||||
assert ChatSession._merge_consecutive_assistant_messages(single) == single
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
# add_tool_call_to_current_turn #
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
_raw_tc = {
|
|
||||||
"id": "tc1",
|
|
||||||
"type": "function",
|
|
||||||
"function": {"name": "f", "arguments": "{}"},
|
|
||||||
}
|
|
||||||
_raw_tc2 = {
|
|
||||||
"id": "tc2",
|
|
||||||
"type": "function",
|
|
||||||
"function": {"name": "g", "arguments": "{}"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_tool_call_appends_to_existing_assistant():
|
|
||||||
"""When the last assistant is from the current turn, tool_call is added to it."""
|
|
||||||
session = ChatSession.new(user_id="u")
|
|
||||||
session.messages = [
|
|
||||||
ChatMessage(role="user", content="hi"),
|
|
||||||
ChatMessage(role="assistant", content="working on it"),
|
|
||||||
]
|
|
||||||
session.add_tool_call_to_current_turn(_raw_tc)
|
|
||||||
|
|
||||||
assert len(session.messages) == 2 # no new message created
|
|
||||||
assert session.messages[1].tool_calls == [_raw_tc]
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_tool_call_creates_assistant_when_none_exists():
|
|
||||||
"""When there's no current-turn assistant, a new one is created."""
|
|
||||||
session = ChatSession.new(user_id="u")
|
|
||||||
session.messages = [
|
|
||||||
ChatMessage(role="user", content="hi"),
|
|
||||||
]
|
|
||||||
session.add_tool_call_to_current_turn(_raw_tc)
|
|
||||||
|
|
||||||
assert len(session.messages) == 2
|
|
||||||
assert session.messages[1].role == "assistant"
|
|
||||||
assert session.messages[1].tool_calls == [_raw_tc]
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_tool_call_does_not_cross_user_boundary():
|
|
||||||
"""A user message acts as a boundary — previous assistant is not modified."""
|
|
||||||
session = ChatSession.new(user_id="u")
|
|
||||||
session.messages = [
|
|
||||||
ChatMessage(role="assistant", content="old turn"),
|
|
||||||
ChatMessage(role="user", content="new message"),
|
|
||||||
]
|
|
||||||
session.add_tool_call_to_current_turn(_raw_tc)
|
|
||||||
|
|
||||||
assert len(session.messages) == 3 # new assistant was created
|
|
||||||
assert session.messages[0].tool_calls is None # old assistant untouched
|
|
||||||
assert session.messages[2].role == "assistant"
|
|
||||||
assert session.messages[2].tool_calls == [_raw_tc]
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_tool_call_multiple_times():
|
|
||||||
"""Multiple long-running tool calls accumulate on the same assistant."""
|
|
||||||
session = ChatSession.new(user_id="u")
|
|
||||||
session.messages = [
|
|
||||||
ChatMessage(role="user", content="hi"),
|
|
||||||
ChatMessage(role="assistant", content="doing stuff"),
|
|
||||||
]
|
|
||||||
session.add_tool_call_to_current_turn(_raw_tc)
|
|
||||||
# Simulate a pending tool result in between (like _yield_tool_call does)
|
|
||||||
session.messages.append(
|
|
||||||
ChatMessage(role="tool", content="pending", tool_call_id="tc1")
|
|
||||||
)
|
|
||||||
session.add_tool_call_to_current_turn(_raw_tc2)
|
|
||||||
|
|
||||||
assert len(session.messages) == 3 # user, assistant, tool — no extra assistant
|
|
||||||
assert session.messages[1].tool_calls == [_raw_tc, _raw_tc2]
|
|
||||||
|
|
||||||
|
|
||||||
def test_to_openai_messages_merges_split_assistants():
|
|
||||||
"""End-to-end: session with split assistants produces valid OpenAI messages."""
|
|
||||||
session = ChatSession.new(user_id="u")
|
|
||||||
session.messages = [
|
|
||||||
ChatMessage(role="user", content="build agent"),
|
|
||||||
ChatMessage(role="assistant", content="Let me build that"),
|
|
||||||
ChatMessage(
|
|
||||||
role="assistant",
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"id": "tc1",
|
|
||||||
"type": "function",
|
|
||||||
"function": {"name": "create_agent", "arguments": "{}"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ChatMessage(role="tool", content="done", tool_call_id="tc1"),
|
|
||||||
ChatMessage(role="assistant", content="Saved!"),
|
|
||||||
ChatMessage(role="user", content="show me an example run"),
|
|
||||||
]
|
|
||||||
openai_msgs = session.to_openai_messages()
|
|
||||||
|
|
||||||
# The two consecutive assistants at index 1,2 should be merged
|
|
||||||
roles = [m["role"] for m in openai_msgs]
|
|
||||||
assert roles == ["user", "assistant", "tool", "assistant", "user"]
|
|
||||||
|
|
||||||
# The merged assistant should have both content and tool_calls
|
|
||||||
merged = cast(ChatCompletionAssistantMessageParam, openai_msgs[1])
|
|
||||||
assert merged.get("content") == "Let me build that"
|
|
||||||
tc_list = merged.get("tool_calls")
|
|
||||||
assert tc_list is not None and len(list(tc_list)) == 1
|
|
||||||
assert list(tc_list)[0]["id"] == "tc1"
|
|
||||||
|
|||||||
@@ -10,8 +10,6 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from backend.util.json import dumps as json_dumps
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseType(str, Enum):
|
class ResponseType(str, Enum):
|
||||||
"""Types of streaming responses following AI SDK protocol."""
|
"""Types of streaming responses following AI SDK protocol."""
|
||||||
@@ -20,10 +18,6 @@ class ResponseType(str, Enum):
|
|||||||
START = "start"
|
START = "start"
|
||||||
FINISH = "finish"
|
FINISH = "finish"
|
||||||
|
|
||||||
# Step lifecycle (one LLM API call within a message)
|
|
||||||
START_STEP = "start-step"
|
|
||||||
FINISH_STEP = "finish-step"
|
|
||||||
|
|
||||||
# Text streaming
|
# Text streaming
|
||||||
TEXT_START = "text-start"
|
TEXT_START = "text-start"
|
||||||
TEXT_DELTA = "text-delta"
|
TEXT_DELTA = "text-delta"
|
||||||
@@ -63,16 +57,6 @@ class StreamStart(StreamBaseResponse):
|
|||||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_sse(self) -> str:
|
|
||||||
"""Convert to SSE format, excluding non-protocol fields like taskId."""
|
|
||||||
import json
|
|
||||||
|
|
||||||
data: dict[str, Any] = {
|
|
||||||
"type": self.type.value,
|
|
||||||
"messageId": self.messageId,
|
|
||||||
}
|
|
||||||
return f"data: {json.dumps(data)}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
class StreamFinish(StreamBaseResponse):
|
class StreamFinish(StreamBaseResponse):
|
||||||
"""End of message/stream."""
|
"""End of message/stream."""
|
||||||
@@ -80,26 +64,6 @@ class StreamFinish(StreamBaseResponse):
|
|||||||
type: ResponseType = ResponseType.FINISH
|
type: ResponseType = ResponseType.FINISH
|
||||||
|
|
||||||
|
|
||||||
class StreamStartStep(StreamBaseResponse):
|
|
||||||
"""Start of a step (one LLM API call within a message).
|
|
||||||
|
|
||||||
The AI SDK uses this to add a step-start boundary to message.parts,
|
|
||||||
enabling visual separation between multiple LLM calls in a single message.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.START_STEP
|
|
||||||
|
|
||||||
|
|
||||||
class StreamFinishStep(StreamBaseResponse):
|
|
||||||
"""End of a step (one LLM API call within a message).
|
|
||||||
|
|
||||||
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
|
|
||||||
so the next LLM call in a tool-call continuation starts with clean state.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.FINISH_STEP
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Text Streaming ==========
|
# ========== Text Streaming ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -153,7 +117,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||||
# Keep these for internal backend use
|
# Additional fields for internal use (not part of AI SDK spec but useful)
|
||||||
toolName: str | None = Field(
|
toolName: str | None = Field(
|
||||||
default=None, description="Name of the tool that was executed"
|
default=None, description="Name of the tool that was executed"
|
||||||
)
|
)
|
||||||
@@ -161,17 +125,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
default=True, description="Whether the tool execution succeeded"
|
default=True, description="Whether the tool execution succeeded"
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_sse(self) -> str:
|
|
||||||
"""Convert to SSE format, excluding non-spec fields."""
|
|
||||||
import json
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"type": self.type.value,
|
|
||||||
"toolCallId": self.toolCallId,
|
|
||||||
"output": self.output,
|
|
||||||
}
|
|
||||||
return f"data: {json.dumps(data)}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Other ==========
|
# ========== Other ==========
|
||||||
|
|
||||||
@@ -195,18 +148,6 @@ class StreamError(StreamBaseResponse):
|
|||||||
default=None, description="Additional error details"
|
default=None, description="Additional error details"
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_sse(self) -> str:
|
|
||||||
"""Convert to SSE format, only emitting fields required by AI SDK protocol.
|
|
||||||
|
|
||||||
The AI SDK uses z.strictObject({type, errorText}) which rejects
|
|
||||||
any extra fields like `code` or `details`.
|
|
||||||
"""
|
|
||||||
data = {
|
|
||||||
"type": self.type.value,
|
|
||||||
"errorText": self.errorText,
|
|
||||||
}
|
|
||||||
return f"data: {json_dumps(data)}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
class StreamHeartbeat(StreamBaseResponse):
|
class StreamHeartbeat(StreamBaseResponse):
|
||||||
"""Heartbeat to keep SSE connection alive during long-running operations.
|
"""Heartbeat to keep SSE connection alive during long-running operations.
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from autogpt_libs import auth
|
from autogpt_libs import auth
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -17,29 +17,7 @@ from . import stream_registry
|
|||||||
from .completion_handler import process_operation_failure, process_operation_success
|
from .completion_handler import process_operation_failure, process_operation_success
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||||
from .response_model import StreamFinish, StreamHeartbeat
|
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
||||||
from .tools.models import (
|
|
||||||
AgentDetailsResponse,
|
|
||||||
AgentOutputResponse,
|
|
||||||
AgentPreviewResponse,
|
|
||||||
AgentSavedResponse,
|
|
||||||
AgentsFoundResponse,
|
|
||||||
BlockListResponse,
|
|
||||||
BlockOutputResponse,
|
|
||||||
ClarificationNeededResponse,
|
|
||||||
DocPageResponse,
|
|
||||||
DocSearchResultsResponse,
|
|
||||||
ErrorResponse,
|
|
||||||
ExecutionStartedResponse,
|
|
||||||
InputValidationErrorResponse,
|
|
||||||
NeedLoginResponse,
|
|
||||||
NoResultsResponse,
|
|
||||||
OperationInProgressResponse,
|
|
||||||
OperationPendingResponse,
|
|
||||||
OperationStartedResponse,
|
|
||||||
SetupRequirementsResponse,
|
|
||||||
UnderstandingUpdatedResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
@@ -288,36 +266,12 @@ async def stream_chat_post(
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
|
||||||
|
|
||||||
stream_start_time = time.perf_counter()
|
|
||||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
|
||||||
if user_id:
|
|
||||||
log_meta["user_id"] = user_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
|
||||||
f"user={user_id}, message_len={len(request.message)}",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
|
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
logger.info(
|
|
||||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a task in the stream registry for reconnection support
|
# Create a task in the stream registry for reconnection support
|
||||||
task_id = str(uuid_module.uuid4())
|
task_id = str(uuid_module.uuid4())
|
||||||
operation_id = str(uuid_module.uuid4())
|
operation_id = str(uuid_module.uuid4())
|
||||||
log_meta["task_id"] = task_id
|
|
||||||
|
|
||||||
task_create_start = time.perf_counter()
|
|
||||||
await stream_registry.create_task(
|
await stream_registry.create_task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -326,28 +280,14 @@ async def stream_chat_post(
|
|||||||
tool_name="chat",
|
tool_name="chat",
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
)
|
)
|
||||||
logger.info(
|
|
||||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Background task that runs the AI generation independently of SSE connection
|
# Background task that runs the AI generation independently of SSE connection
|
||||||
async def run_ai_generation():
|
async def run_ai_generation():
|
||||||
import time as time_module
|
|
||||||
|
|
||||||
gen_start_time = time_module.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
first_chunk_time, ttfc = None, None
|
|
||||||
chunk_count = 0
|
|
||||||
try:
|
try:
|
||||||
|
# Emit a start event with task_id for reconnection
|
||||||
|
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||||
|
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||||
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
session_id,
|
session_id,
|
||||||
request.message,
|
request.message,
|
||||||
@@ -355,79 +295,25 @@ async def stream_chat_post(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
context=request.context,
|
context=request.context,
|
||||||
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
|
|
||||||
):
|
):
|
||||||
chunk_count += 1
|
|
||||||
if first_chunk_time is None:
|
|
||||||
first_chunk_time = time_module.perf_counter()
|
|
||||||
ttfc = first_chunk_time - gen_start_time
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"chunk_type": type(chunk).__name__,
|
|
||||||
"time_to_first_chunk_ms": ttfc * 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Write to Redis (subscribers will receive via XREAD)
|
# Write to Redis (subscribers will receive via XREAD)
|
||||||
await stream_registry.publish_chunk(task_id, chunk)
|
await stream_registry.publish_chunk(task_id, chunk)
|
||||||
|
|
||||||
gen_end_time = time_module.perf_counter()
|
# Mark task as completed
|
||||||
total_time = (gen_end_time - gen_start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
|
||||||
f"task={task_id}, session={session_id}, "
|
|
||||||
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"time_to_first_chunk_ms": (
|
|
||||||
ttfc * 1000 if ttfc is not None else None
|
|
||||||
),
|
|
||||||
"n_chunks": chunk_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await stream_registry.mark_task_completed(task_id, "completed")
|
await stream_registry.mark_task_completed(task_id, "completed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed = time_module.perf_counter() - gen_start_time
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
f"Error in background AI generation for session {session_id}: {e}"
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed * 1000,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
await stream_registry.mark_task_completed(task_id, "failed")
|
await stream_registry.mark_task_completed(task_id, "failed")
|
||||||
|
|
||||||
# Start the AI generation in a background task
|
# Start the AI generation in a background task
|
||||||
bg_task = asyncio.create_task(run_ai_generation())
|
bg_task = asyncio.create_task(run_ai_generation())
|
||||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
# SSE endpoint that subscribes to the task's stream
|
# SSE endpoint that subscribes to the task's stream
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
import time as time_module
|
|
||||||
|
|
||||||
event_gen_start = time_module.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
|
||||||
f"user={user_id}",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
subscriber_queue = None
|
subscriber_queue = None
|
||||||
first_chunk_yielded = False
|
|
||||||
chunks_yielded = 0
|
|
||||||
try:
|
try:
|
||||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
@@ -442,70 +328,22 @@ async def stream_chat_post(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Read from the subscriber queue and yield to SSE
|
# Read from the subscriber queue and yield to SSE
|
||||||
logger.info(
|
|
||||||
"[TIMING] Starting to read from subscriber_queue",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||||
chunks_yielded += 1
|
|
||||||
|
|
||||||
if not first_chunk_yielded:
|
|
||||||
first_chunk_yielded = True
|
|
||||||
elapsed = time_module.perf_counter() - event_gen_start
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
|
|
||||||
f"type={type(chunk).__name__}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"chunk_type": type(chunk).__name__,
|
|
||||||
"elapsed_ms": elapsed * 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
yield chunk.to_sse()
|
yield chunk.to_sse()
|
||||||
|
|
||||||
# Check for finish signal
|
# Check for finish signal
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
total_time = time_module.perf_counter() - event_gen_start
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
|
|
||||||
f"n_chunks={chunks_yielded}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"chunks_yielded": chunks_yielded,
|
|
||||||
"total_time_ms": total_time * 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
# Send heartbeat to keep connection alive
|
||||||
yield StreamHeartbeat().to_sse()
|
yield StreamHeartbeat().to_sse()
|
||||||
|
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
logger.info(
|
|
||||||
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"chunks_yielded": chunks_yielded,
|
|
||||||
"reason": "client_disconnect",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
pass # Client disconnected - background task continues
|
pass # Client disconnected - background task continues
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
||||||
logger.error(
|
|
||||||
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||||
if subscriber_queue is not None:
|
if subscriber_queue is not None:
|
||||||
@@ -519,18 +357,6 @@ async def stream_chat_post(
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||||
total_time = time_module.perf_counter() - event_gen_start
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
|
||||||
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time * 1000,
|
|
||||||
"chunks_yielded": chunks_yielded,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
@@ -548,90 +374,63 @@ async def stream_chat_post(
|
|||||||
@router.get(
|
@router.get(
|
||||||
"/sessions/{session_id}/stream",
|
"/sessions/{session_id}/stream",
|
||||||
)
|
)
|
||||||
async def resume_session_stream(
|
async def stream_chat_get(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
user_id: str | None = Depends(auth.get_user_id),
|
||||||
|
is_user_message: bool = Query(default=True),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Resume an active stream for a session.
|
Stream chat responses for a session (GET - legacy endpoint).
|
||||||
|
|
||||||
Called by the AI SDK's ``useChat(resume: true)`` on page load.
|
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||||
Checks for an active (in-progress) task on the session and either replays
|
- Text fragments as they are generated
|
||||||
the full SSE stream or returns 204 No Content if nothing is running.
|
- Tool call UI elements (if invoked)
|
||||||
|
- Tool execution results
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The chat session identifier.
|
session_id: The chat session identifier to associate with the streamed messages.
|
||||||
|
message: The user's new message to process.
|
||||||
user_id: Optional authenticated user ID.
|
user_id: Optional authenticated user ID.
|
||||||
|
is_user_message: Whether the message is a user message.
|
||||||
Returns:
|
Returns:
|
||||||
StreamingResponse (SSE) when an active stream exists,
|
StreamingResponse: SSE-formatted response chunks.
|
||||||
or 204 No Content when there is nothing to resume.
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
|
||||||
session_id, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not active_task:
|
|
||||||
return Response(status_code=204)
|
|
||||||
|
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
|
||||||
task_id=active_task.task_id,
|
|
||||||
user_id=user_id,
|
|
||||||
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
|
||||||
)
|
|
||||||
|
|
||||||
if subscriber_queue is None:
|
|
||||||
return Response(status_code=204)
|
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
first_chunk_type: str | None = None
|
first_chunk_type: str | None = None
|
||||||
try:
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
while True:
|
session_id,
|
||||||
try:
|
message,
|
||||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
is_user_message=is_user_message,
|
||||||
if chunk_count < 3:
|
user_id=user_id,
|
||||||
logger.info(
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
"Resume stream chunk",
|
):
|
||||||
extra={
|
if chunk_count < 3:
|
||||||
"session_id": session_id,
|
logger.info(
|
||||||
"chunk_type": str(chunk.type),
|
"Chat stream chunk",
|
||||||
},
|
extra={
|
||||||
)
|
"session_id": session_id,
|
||||||
if not first_chunk_type:
|
"chunk_type": str(chunk.type),
|
||||||
first_chunk_type = str(chunk.type)
|
},
|
||||||
chunk_count += 1
|
|
||||||
yield chunk.to_sse()
|
|
||||||
|
|
||||||
if isinstance(chunk, StreamFinish):
|
|
||||||
break
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
yield StreamHeartbeat().to_sse()
|
|
||||||
except GeneratorExit:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await stream_registry.unsubscribe_from_task(
|
|
||||||
active_task.task_id, subscriber_queue
|
|
||||||
)
|
)
|
||||||
except Exception as unsub_err:
|
if not first_chunk_type:
|
||||||
logger.error(
|
first_chunk_type = str(chunk.type)
|
||||||
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
chunk_count += 1
|
||||||
exc_info=True,
|
yield chunk.to_sse()
|
||||||
)
|
logger.info(
|
||||||
logger.info(
|
"Chat stream completed",
|
||||||
"Resume stream completed",
|
extra={
|
||||||
extra={
|
"session_id": session_id,
|
||||||
"session_id": session_id,
|
"chunk_count": chunk_count,
|
||||||
"n_chunks": chunk_count,
|
"first_chunk_type": first_chunk_type,
|
||||||
"first_chunk_type": first_chunk_type,
|
},
|
||||||
},
|
)
|
||||||
)
|
# AI SDK protocol termination
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
event_generator(),
|
||||||
@@ -639,8 +438,8 @@ async def resume_session_stream(
|
|||||||
headers={
|
headers={
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
"Connection": "keep-alive",
|
"Connection": "keep-alive",
|
||||||
"X-Accel-Buffering": "no",
|
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||||
"x-vercel-ai-ui-message-stream": "v1",
|
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -952,42 +751,3 @@ async def health_check() -> dict:
|
|||||||
"service": "chat",
|
"service": "chat",
|
||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
|
|
||||||
|
|
||||||
ToolResponseUnion = (
|
|
||||||
AgentsFoundResponse
|
|
||||||
| NoResultsResponse
|
|
||||||
| AgentDetailsResponse
|
|
||||||
| SetupRequirementsResponse
|
|
||||||
| ExecutionStartedResponse
|
|
||||||
| NeedLoginResponse
|
|
||||||
| ErrorResponse
|
|
||||||
| InputValidationErrorResponse
|
|
||||||
| AgentOutputResponse
|
|
||||||
| UnderstandingUpdatedResponse
|
|
||||||
| AgentPreviewResponse
|
|
||||||
| AgentSavedResponse
|
|
||||||
| ClarificationNeededResponse
|
|
||||||
| BlockListResponse
|
|
||||||
| BlockOutputResponse
|
|
||||||
| DocSearchResultsResponse
|
|
||||||
| DocPageResponse
|
|
||||||
| OperationStartedResponse
|
|
||||||
| OperationPendingResponse
|
|
||||||
| OperationInProgressResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/schema/tool-responses",
|
|
||||||
response_model=ToolResponseUnion,
|
|
||||||
include_in_schema=True,
|
|
||||||
summary="[Dummy] Tool response type export for codegen",
|
|
||||||
description="This endpoint is not meant to be called. It exists solely to "
|
|
||||||
"expose tool response models in the OpenAPI schema for frontend codegen.",
|
|
||||||
)
|
|
||||||
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
|
|
||||||
"""Never called at runtime. Exists only so Orval generates TS types."""
|
|
||||||
raise HTTPException(status_code=501, detail="Schema-only endpoint")
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from backend.data.understanding import (
|
|||||||
get_business_understanding,
|
get_business_understanding,
|
||||||
)
|
)
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.util.settings import AppEnvironment, Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
from . import db as chat_db
|
from . import db as chat_db
|
||||||
from . import stream_registry
|
from . import stream_registry
|
||||||
@@ -52,10 +52,8 @@ from .response_model import (
|
|||||||
StreamBaseResponse,
|
StreamBaseResponse,
|
||||||
StreamError,
|
StreamError,
|
||||||
StreamFinish,
|
StreamFinish,
|
||||||
StreamFinishStep,
|
|
||||||
StreamHeartbeat,
|
StreamHeartbeat,
|
||||||
StreamStart,
|
StreamStart,
|
||||||
StreamStartStep,
|
|
||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamTextEnd,
|
StreamTextEnd,
|
||||||
StreamTextStart,
|
StreamTextStart,
|
||||||
@@ -224,18 +222,8 @@ async def _get_system_prompt_template(context: str) -> str:
|
|||||||
try:
|
try:
|
||||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||||
# Use asyncio.to_thread to avoid blocking the event loop
|
# Use asyncio.to_thread to avoid blocking the event loop
|
||||||
# In non-production environments, fetch the latest prompt version
|
|
||||||
# instead of the production-labeled version for easier testing
|
|
||||||
label = (
|
|
||||||
None
|
|
||||||
if settings.config.app_env == AppEnvironment.PRODUCTION
|
|
||||||
else "latest"
|
|
||||||
)
|
|
||||||
prompt = await asyncio.to_thread(
|
prompt = await asyncio.to_thread(
|
||||||
langfuse.get_prompt,
|
langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
|
||||||
config.langfuse_prompt_name,
|
|
||||||
label=label,
|
|
||||||
cache_ttl_seconds=0,
|
|
||||||
)
|
)
|
||||||
return prompt.compile(users_information=context)
|
return prompt.compile(users_information=context)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -353,10 +341,6 @@ async def stream_chat_completion(
|
|||||||
retry_count: int = 0,
|
retry_count: int = 0,
|
||||||
session: ChatSession | None = None,
|
session: ChatSession | None = None,
|
||||||
context: dict[str, str] | None = None, # {url: str, content: str}
|
context: dict[str, str] | None = None, # {url: str, content: str}
|
||||||
_continuation_message_id: (
|
|
||||||
str | None
|
|
||||||
) = None, # Internal: reuse message ID for tool call continuations
|
|
||||||
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
|
|
||||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
"""Main entry point for streaming chat completions with database handling.
|
"""Main entry point for streaming chat completions with database handling.
|
||||||
|
|
||||||
@@ -377,45 +361,21 @@ async def stream_chat_completion(
|
|||||||
ValueError: If max_context_messages is exceeded
|
ValueError: If max_context_messages is exceeded
|
||||||
|
|
||||||
"""
|
"""
|
||||||
completion_start = time.monotonic()
|
|
||||||
|
|
||||||
# Build log metadata for structured logging
|
|
||||||
log_meta = {"component": "ChatService", "session_id": session_id}
|
|
||||||
if user_id:
|
|
||||||
log_meta["user_id"] = user_id
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
||||||
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"message_len": len(message) if message else 0,
|
|
||||||
"is_user_message": is_user_message,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only fetch from Redis if session not provided (initial call)
|
# Only fetch from Redis if session not provided (initial call)
|
||||||
if session is None:
|
if session is None:
|
||||||
fetch_start = time.monotonic()
|
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
fetch_time = (time.monotonic() - fetch_start) * 1000
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
|
||||||
f"n_messages={len(session.messages) if session else 0}",
|
f"message_count={len(session.messages) if session else 0}"
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": fetch_time,
|
|
||||||
"n_messages": len(session.messages) if session else 0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
f"Using provided session object: {session.session_id}, "
|
||||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
f"message_count={len(session.messages)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
@@ -436,25 +396,17 @@ async def stream_chat_completion(
|
|||||||
|
|
||||||
# Track user message in PostHog
|
# Track user message in PostHog
|
||||||
if is_user_message:
|
if is_user_message:
|
||||||
posthog_start = time.monotonic()
|
|
||||||
track_user_message(
|
track_user_message(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
message_length=len(message),
|
message_length=len(message),
|
||||||
)
|
)
|
||||||
posthog_time = (time.monotonic() - posthog_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
upsert_start = time.monotonic()
|
|
||||||
session = await upsert_chat_session(session)
|
|
||||||
upsert_time = (time.monotonic() - upsert_start) * 1000
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
f"Upserting session: {session.session_id} with user id {session.user_id}, "
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
|
f"message_count={len(session.messages)}"
|
||||||
)
|
)
|
||||||
|
session = await upsert_chat_session(session)
|
||||||
assert session, "Session not found"
|
assert session, "Session not found"
|
||||||
|
|
||||||
# Generate title for new sessions on first user message (non-blocking)
|
# Generate title for new sessions on first user message (non-blocking)
|
||||||
@@ -492,13 +444,7 @@ async def stream_chat_completion(
|
|||||||
asyncio.create_task(_update_title())
|
asyncio.create_task(_update_title())
|
||||||
|
|
||||||
# Build system prompt with business understanding
|
# Build system prompt with business understanding
|
||||||
prompt_start = time.monotonic()
|
|
||||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||||
prompt_time = (time.monotonic() - prompt_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize variables for streaming
|
# Initialize variables for streaming
|
||||||
assistant_response = ChatMessage(
|
assistant_response = ChatMessage(
|
||||||
@@ -523,27 +469,13 @@ async def stream_chat_completion(
|
|||||||
# Generate unique IDs for AI SDK protocol
|
# Generate unique IDs for AI SDK protocol
|
||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
|
|
||||||
is_continuation = _continuation_message_id is not None
|
message_id = str(uuid_module.uuid4())
|
||||||
message_id = _continuation_message_id or str(uuid_module.uuid4())
|
|
||||||
text_block_id = str(uuid_module.uuid4())
|
text_block_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
# Only yield message start for the initial call, not for continuations.
|
# Yield message start
|
||||||
setup_time = (time.monotonic() - completion_start) * 1000
|
yield StreamStart(messageId=message_id)
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
|
||||||
)
|
|
||||||
if not is_continuation:
|
|
||||||
yield StreamStart(messageId=message_id, taskId=_task_id)
|
|
||||||
|
|
||||||
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
|
|
||||||
yield StreamStartStep()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(
|
|
||||||
"[TIMING] Calling _stream_chat_chunks",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
async for chunk in _stream_chat_chunks(
|
async for chunk in _stream_chat_chunks(
|
||||||
session=session,
|
session=session,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -643,10 +575,6 @@ async def stream_chat_completion(
|
|||||||
)
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamFinish):
|
elif isinstance(chunk, StreamFinish):
|
||||||
if has_done_tool_call:
|
|
||||||
# Tool calls happened — close the step but don't send message-level finish.
|
|
||||||
# The continuation will open a new step, and finish will come at the end.
|
|
||||||
yield StreamFinishStep()
|
|
||||||
if not has_done_tool_call:
|
if not has_done_tool_call:
|
||||||
# Emit text-end before finish if we received text but haven't closed it
|
# Emit text-end before finish if we received text but haven't closed it
|
||||||
if has_received_text and not text_streaming_ended:
|
if has_received_text and not text_streaming_ended:
|
||||||
@@ -678,8 +606,6 @@ async def stream_chat_completion(
|
|||||||
has_saved_assistant_message = True
|
has_saved_assistant_message = True
|
||||||
|
|
||||||
has_yielded_end = True
|
has_yielded_end = True
|
||||||
# Emit finish-step before finish (resets AI SDK text/reasoning state)
|
|
||||||
yield StreamFinishStep()
|
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamError):
|
elif isinstance(chunk, StreamError):
|
||||||
has_yielded_error = True
|
has_yielded_error = True
|
||||||
@@ -692,9 +618,6 @@ async def stream_chat_completion(
|
|||||||
total_tokens=chunk.totalTokens,
|
total_tokens=chunk.totalTokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(chunk, StreamHeartbeat):
|
|
||||||
# Pass through heartbeat to keep SSE connection alive
|
|
||||||
yield chunk
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||||
|
|
||||||
@@ -729,10 +652,6 @@ async def stream_chat_completion(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
||||||
)
|
)
|
||||||
# Close the current step before retrying so the recursive call's
|
|
||||||
# StreamStartStep doesn't produce unbalanced step events.
|
|
||||||
if not has_yielded_end:
|
|
||||||
yield StreamFinishStep()
|
|
||||||
should_retry = True
|
should_retry = True
|
||||||
else:
|
else:
|
||||||
# Non-retryable error or max retries exceeded
|
# Non-retryable error or max retries exceeded
|
||||||
@@ -768,7 +687,6 @@ async def stream_chat_completion(
|
|||||||
error_response = StreamError(errorText=error_message)
|
error_response = StreamError(errorText=error_message)
|
||||||
yield error_response
|
yield error_response
|
||||||
if not has_yielded_end:
|
if not has_yielded_end:
|
||||||
yield StreamFinishStep()
|
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -783,8 +701,6 @@ async def stream_chat_completion(
|
|||||||
retry_count=retry_count + 1,
|
retry_count=retry_count + 1,
|
||||||
session=session,
|
session=session,
|
||||||
context=context,
|
context=context,
|
||||||
_continuation_message_id=message_id, # Reuse message ID since start was already sent
|
|
||||||
_task_id=_task_id,
|
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
return # Exit after retry to avoid double-saving in finally block
|
return # Exit after retry to avoid double-saving in finally block
|
||||||
@@ -800,13 +716,9 @@ async def stream_chat_completion(
|
|||||||
# Build the messages list in the correct order
|
# Build the messages list in the correct order
|
||||||
messages_to_save: list[ChatMessage] = []
|
messages_to_save: list[ChatMessage] = []
|
||||||
|
|
||||||
# Add assistant message with tool_calls if any.
|
# Add assistant message with tool_calls if any
|
||||||
# Use extend (not assign) to preserve tool_calls already added by
|
|
||||||
# _yield_tool_call for long-running tools.
|
|
||||||
if accumulated_tool_calls:
|
if accumulated_tool_calls:
|
||||||
if not assistant_response.tool_calls:
|
assistant_response.tool_calls = accumulated_tool_calls
|
||||||
assistant_response.tool_calls = []
|
|
||||||
assistant_response.tool_calls.extend(accumulated_tool_calls)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||||
)
|
)
|
||||||
@@ -858,8 +770,6 @@ async def stream_chat_completion(
|
|||||||
session=session, # Pass session object to avoid Redis refetch
|
session=session, # Pass session object to avoid Redis refetch
|
||||||
context=context,
|
context=context,
|
||||||
tool_call_response=str(tool_response_messages),
|
tool_call_response=str(tool_response_messages),
|
||||||
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
|
|
||||||
_task_id=_task_id,
|
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@@ -970,21 +880,9 @@ async def _stream_chat_chunks(
|
|||||||
SSE formatted JSON response objects
|
SSE formatted JSON response objects
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import time as time_module
|
|
||||||
|
|
||||||
stream_chunks_start = time_module.perf_counter()
|
|
||||||
model = config.model
|
model = config.model
|
||||||
|
|
||||||
# Build log metadata for structured logging
|
logger.info("Starting pure chat stream")
|
||||||
log_meta = {"component": "ChatService", "session_id": session.session_id}
|
|
||||||
if session.user_id:
|
|
||||||
log_meta["user_id"] = session.user_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
|
|
||||||
f"user={session.user_id}, n_messages={len(session.messages)}",
|
|
||||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = session.to_openai_messages()
|
messages = session.to_openai_messages()
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
@@ -995,18 +893,12 @@ async def _stream_chat_chunks(
|
|||||||
messages = [system_message] + messages
|
messages = [system_message] + messages
|
||||||
|
|
||||||
# Apply context window management
|
# Apply context window management
|
||||||
context_start = time_module.perf_counter()
|
|
||||||
context_result = await _manage_context_window(
|
context_result = await _manage_context_window(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=model,
|
model=model,
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
base_url=config.base_url,
|
base_url=config.base_url,
|
||||||
)
|
)
|
||||||
context_time = (time_module.perf_counter() - context_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
if context_result.error:
|
if context_result.error:
|
||||||
if "System prompt dropped" in context_result.error:
|
if "System prompt dropped" in context_result.error:
|
||||||
@@ -1041,19 +933,9 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
while retry_count <= MAX_RETRIES:
|
while retry_count <= MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
|
|
||||||
retry_info = (
|
|
||||||
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
|
f"Creating OpenAI chat completion stream..."
|
||||||
extra={
|
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"retry_count": retry_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build extra_body for OpenRouter tracing and PostHog analytics
|
# Build extra_body for OpenRouter tracing and PostHog analytics
|
||||||
@@ -1070,11 +952,6 @@ async def _stream_chat_chunks(
|
|||||||
:128
|
:128
|
||||||
] # OpenRouter limit
|
] # OpenRouter limit
|
||||||
|
|
||||||
# Enable adaptive thinking for Anthropic models via OpenRouter
|
|
||||||
if config.thinking_enabled and "anthropic" in model.lower():
|
|
||||||
extra_body["reasoning"] = {"enabled": True}
|
|
||||||
|
|
||||||
api_call_start = time_module.perf_counter()
|
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||||
@@ -1084,11 +961,6 @@ async def _stream_chat_chunks(
|
|||||||
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Variables to accumulate tool calls
|
# Variables to accumulate tool calls
|
||||||
tool_calls: list[dict[str, Any]] = []
|
tool_calls: list[dict[str, Any]] = []
|
||||||
@@ -1099,13 +971,10 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
# Track if we've started the text block
|
# Track if we've started the text block
|
||||||
text_started = False
|
text_started = False
|
||||||
first_content_chunk = True
|
|
||||||
chunk_count = 0
|
|
||||||
|
|
||||||
# Process the stream
|
# Process the stream
|
||||||
chunk: ChatCompletionChunk
|
chunk: ChatCompletionChunk
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
chunk_count += 1
|
|
||||||
if chunk.usage:
|
if chunk.usage:
|
||||||
yield StreamUsage(
|
yield StreamUsage(
|
||||||
promptTokens=chunk.usage.prompt_tokens,
|
promptTokens=chunk.usage.prompt_tokens,
|
||||||
@@ -1128,23 +997,6 @@ async def _stream_chat_chunks(
|
|||||||
if not text_started and text_block_id:
|
if not text_started and text_block_id:
|
||||||
yield StreamTextStart(id=text_block_id)
|
yield StreamTextStart(id=text_block_id)
|
||||||
text_started = True
|
text_started = True
|
||||||
# Log timing for first content chunk
|
|
||||||
if first_content_chunk:
|
|
||||||
first_content_chunk = False
|
|
||||||
ttfc = (
|
|
||||||
time_module.perf_counter() - api_call_start
|
|
||||||
) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
|
|
||||||
f"(since API call), n_chunks={chunk_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"time_to_first_chunk_ms": ttfc,
|
|
||||||
"n_chunks": chunk_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Stream the text delta
|
# Stream the text delta
|
||||||
text_response = StreamTextDelta(
|
text_response = StreamTextDelta(
|
||||||
id=text_block_id or "",
|
id=text_block_id or "",
|
||||||
@@ -1201,21 +1053,7 @@ async def _stream_chat_chunks(
|
|||||||
toolName=tool_calls[idx]["function"]["name"],
|
toolName=tool_calls[idx]["function"]["name"],
|
||||||
)
|
)
|
||||||
emitted_start_for_idx.add(idx)
|
emitted_start_for_idx.add(idx)
|
||||||
stream_duration = time_module.perf_counter() - api_call_start
|
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
|
|
||||||
f"duration={stream_duration:.2f}s, "
|
|
||||||
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"stream_duration_ms": stream_duration * 1000,
|
|
||||||
"finish_reason": finish_reason,
|
|
||||||
"n_chunks": chunk_count,
|
|
||||||
"n_tool_calls": len(tool_calls),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Yield all accumulated tool calls after the stream is complete
|
# Yield all accumulated tool calls after the stream is complete
|
||||||
# This ensures all tool call arguments have been fully received
|
# This ensures all tool call arguments have been fully received
|
||||||
@@ -1235,12 +1073,6 @@ async def _stream_chat_chunks(
|
|||||||
# Re-raise to trigger retry logic in the parent function
|
# Re-raise to trigger retry logic in the parent function
|
||||||
raise
|
raise
|
||||||
|
|
||||||
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
|
|
||||||
f"session={session.session_id}, user={session.user_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
|
||||||
)
|
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1408,9 +1240,13 @@ async def _yield_tool_call(
|
|||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attach the tool_call to the current turn's assistant message
|
# Save assistant message with tool_call FIRST (required by LLM)
|
||||||
# (or create one if this is a tool-only response with no text).
|
assistant_message = ChatMessage(
|
||||||
session.add_tool_call_to_current_turn(tool_calls[yield_idx])
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
tool_calls=[tool_calls[yield_idx]],
|
||||||
|
)
|
||||||
|
session.messages.append(assistant_message)
|
||||||
|
|
||||||
# Then save pending tool result
|
# Then save pending tool result
|
||||||
pending_message = ChatMessage(
|
pending_message = ChatMessage(
|
||||||
@@ -1716,7 +1552,6 @@ async def _execute_long_running_tool_with_streaming(
|
|||||||
task_id,
|
task_id,
|
||||||
StreamError(errorText=str(e)),
|
StreamError(errorText=str(e)),
|
||||||
)
|
)
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
|
||||||
await _update_pending_operation(
|
await _update_pending_operation(
|
||||||
@@ -1833,10 +1668,6 @@ async def _generate_llm_continuation(
|
|||||||
if session_id:
|
if session_id:
|
||||||
extra_body["session_id"] = session_id[:128]
|
extra_body["session_id"] = session_id[:128]
|
||||||
|
|
||||||
# Enable adaptive thinking for Anthropic models via OpenRouter
|
|
||||||
if config.thinking_enabled and "anthropic" in config.model.lower():
|
|
||||||
extra_body["reasoning"] = {"enabled": True}
|
|
||||||
|
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
last_error: Exception | None = None
|
last_error: Exception | None = None
|
||||||
response = None
|
response = None
|
||||||
@@ -1967,10 +1798,6 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
if session_id:
|
if session_id:
|
||||||
extra_body["session_id"] = session_id[:128]
|
extra_body["session_id"] = session_id[:128]
|
||||||
|
|
||||||
# Enable adaptive thinking for Anthropic models via OpenRouter
|
|
||||||
if config.thinking_enabled and "anthropic" in config.model.lower():
|
|
||||||
extra_body["reasoning"] = {"enabled": True}
|
|
||||||
|
|
||||||
# Make streaming LLM call (no tools - just text response)
|
# Make streaming LLM call (no tools - just text response)
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
@@ -1982,7 +1809,6 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
|
|
||||||
# Publish start event
|
# Publish start event
|
||||||
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
||||||
await stream_registry.publish_chunk(task_id, StreamStartStep())
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
||||||
|
|
||||||
# Stream the response
|
# Stream the response
|
||||||
@@ -2006,7 +1832,6 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
|
|
||||||
# Publish end events
|
# Publish end events
|
||||||
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
|
||||||
|
|
||||||
if assistant_content:
|
if assistant_content:
|
||||||
# Reload session from DB to avoid race condition with user messages
|
# Reload session from DB to avoid race condition with user messages
|
||||||
@@ -2048,5 +1873,4 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
task_id,
|
task_id,
|
||||||
StreamError(errorText=f"Failed to generate response: {e}"),
|
StreamError(errorText=f"Failed to generate response: {e}"),
|
||||||
)
|
)
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
|||||||
@@ -104,24 +104,6 @@ async def create_task(
|
|||||||
Returns:
|
Returns:
|
||||||
The created ActiveTask instance (metadata only)
|
The created ActiveTask instance (metadata only)
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
# Build log metadata for structured logging
|
|
||||||
log_meta = {
|
|
||||||
"component": "StreamRegistry",
|
|
||||||
"task_id": task_id,
|
|
||||||
"session_id": session_id,
|
|
||||||
}
|
|
||||||
if user_id:
|
|
||||||
log_meta["user_id"] = user_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
|
|
||||||
task = ActiveTask(
|
task = ActiveTask(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -132,18 +114,10 @@ async def create_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store metadata in Redis
|
# Store metadata in Redis
|
||||||
redis_start = time.perf_counter()
|
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
redis_time = (time.perf_counter() - redis_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
op_key = _get_operation_mapping_key(operation_id)
|
op_key = _get_operation_mapping_key(operation_id)
|
||||||
|
|
||||||
hset_start = time.perf_counter()
|
|
||||||
await redis.hset( # type: ignore[misc]
|
await redis.hset( # type: ignore[misc]
|
||||||
meta_key,
|
meta_key,
|
||||||
mapping={
|
mapping={
|
||||||
@@ -157,22 +131,12 @@ async def create_task(
|
|||||||
"created_at": task.created_at.isoformat(),
|
"created_at": task.created_at.isoformat(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
hset_time = (time.perf_counter() - hset_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] redis.hset took {hset_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
await redis.expire(meta_key, config.stream_ttl)
|
await redis.expire(meta_key, config.stream_ttl)
|
||||||
|
|
||||||
# Create operation_id -> task_id mapping for webhook lookups
|
# Create operation_id -> task_id mapping for webhook lookups
|
||||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||||
|
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
logger.debug(f"Created task {task_id} for session {session_id}")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@@ -192,60 +156,26 @@ async def publish_chunk(
|
|||||||
Returns:
|
Returns:
|
||||||
The Redis Stream message ID
|
The Redis Stream message ID
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
chunk_type = type(chunk).__name__
|
|
||||||
chunk_json = chunk.model_dump_json()
|
chunk_json = chunk.model_dump_json()
|
||||||
message_id = "0-0"
|
message_id = "0-0"
|
||||||
|
|
||||||
# Build log metadata
|
|
||||||
log_meta = {
|
|
||||||
"component": "StreamRegistry",
|
|
||||||
"task_id": task_id,
|
|
||||||
"chunk_type": chunk_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Write to Redis Stream for persistence and real-time delivery
|
# Write to Redis Stream for persistence and real-time delivery
|
||||||
xadd_start = time.perf_counter()
|
|
||||||
raw_id = await redis.xadd(
|
raw_id = await redis.xadd(
|
||||||
stream_key,
|
stream_key,
|
||||||
{"data": chunk_json},
|
{"data": chunk_json},
|
||||||
maxlen=config.stream_max_length,
|
maxlen=config.stream_max_length,
|
||||||
)
|
)
|
||||||
xadd_time = (time.perf_counter() - xadd_start) * 1000
|
|
||||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||||
|
|
||||||
# Set TTL on stream to match task metadata TTL
|
# Set TTL on stream to match task metadata TTL
|
||||||
await redis.expire(stream_key, config.stream_ttl)
|
await redis.expire(stream_key, config.stream_ttl)
|
||||||
|
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
# Only log timing for significant chunks or slow operations
|
|
||||||
if (
|
|
||||||
chunk_type
|
|
||||||
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
|
|
||||||
or total_time > 50
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"xadd_time_ms": xadd_time,
|
|
||||||
"message_id": message_id,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
|
f"Failed to publish chunk for task {task_id}: {e}",
|
||||||
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -270,61 +200,24 @@ async def subscribe_to_task(
|
|||||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||||
or user doesn't have access
|
or user doesn't have access
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
# Build log metadata
|
|
||||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
|
||||||
if user_id:
|
|
||||||
log_meta["user_id"] = user_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
|
|
||||||
)
|
|
||||||
|
|
||||||
redis_start = time.perf_counter()
|
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||||
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not meta:
|
if not meta:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
logger.debug(f"Task {task_id} not found in Redis")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"reason": "task_not_found",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||||
task_status = meta.get("status", "")
|
task_status = meta.get("status", "")
|
||||||
task_user_id = meta.get("user_id", "") or None
|
task_user_id = meta.get("user_id", "") or None
|
||||||
log_meta["session_id"] = meta.get("session_id", "")
|
|
||||||
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
# Validate ownership - if task has an owner, requester must match
|
||||||
if task_user_id:
|
if task_user_id:
|
||||||
if user_id != task_user_id:
|
if user_id != task_user_id:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
|
f"User {user_id} denied access to task {task_id} "
|
||||||
extra={
|
f"owned by {task_user_id}"
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"task_owner": task_user_id,
|
|
||||||
"reason": "access_denied",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -332,19 +225,7 @@ async def subscribe_to_task(
|
|||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Step 1: Replay messages from Redis Stream
|
# Step 1: Replay messages from Redis Stream
|
||||||
xread_start = time.perf_counter()
|
|
||||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": xread_time,
|
|
||||||
"task_status": task_status,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
replayed_count = 0
|
replayed_count = 0
|
||||||
replay_last_id = last_message_id
|
replay_last_id = last_message_id
|
||||||
@@ -363,48 +244,19 @@ async def subscribe_to_task(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to replay message: {e}")
|
logger.warning(f"Failed to replay message: {e}")
|
||||||
|
|
||||||
logger.info(
|
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
||||||
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"n_messages_replayed": replayed_count,
|
|
||||||
"replay_last_id": replay_last_id,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 2: If task is still running, start stream listener for live updates
|
# Step 2: If task is still running, start stream listener for live updates
|
||||||
if task_status == "running":
|
if task_status == "running":
|
||||||
logger.info(
|
|
||||||
"[TIMING] Task still running, starting _stream_listener",
|
|
||||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
|
||||||
)
|
|
||||||
listener_task = asyncio.create_task(
|
listener_task = asyncio.create_task(
|
||||||
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
|
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
||||||
)
|
)
|
||||||
# Track listener task for cleanup on unsubscribe
|
# Track listener task for cleanup on unsubscribe
|
||||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||||
else:
|
else:
|
||||||
# Task is completed/failed - add finish marker
|
# Task is completed/failed - add finish marker
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Task already {task_status}, adding StreamFinish",
|
|
||||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
|
||||||
)
|
|
||||||
await subscriber_queue.put(StreamFinish())
|
await subscriber_queue.put(StreamFinish())
|
||||||
|
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
|
|
||||||
f"n_messages_replayed={replayed_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"n_messages_replayed": replayed_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return subscriber_queue
|
return subscriber_queue
|
||||||
|
|
||||||
|
|
||||||
@@ -412,7 +264,6 @@ async def _stream_listener(
|
|||||||
task_id: str,
|
task_id: str,
|
||||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||||
last_replayed_id: str,
|
last_replayed_id: str,
|
||||||
log_meta: dict | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||||
|
|
||||||
@@ -423,27 +274,10 @@ async def _stream_listener(
|
|||||||
task_id: Task ID to listen for
|
task_id: Task ID to listen for
|
||||||
subscriber_queue: Queue to deliver messages to
|
subscriber_queue: Queue to deliver messages to
|
||||||
last_replayed_id: Last message ID from replay (continue from here)
|
last_replayed_id: Last message ID from replay (continue from here)
|
||||||
log_meta: Structured logging metadata
|
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
# Use provided log_meta or build minimal one
|
|
||||||
if log_meta is None:
|
|
||||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
|
|
||||||
)
|
|
||||||
|
|
||||||
queue_id = id(subscriber_queue)
|
queue_id = id(subscriber_queue)
|
||||||
# Track the last successfully delivered message ID for recovery hints
|
# Track the last successfully delivered message ID for recovery hints
|
||||||
last_delivered_id = last_replayed_id
|
last_delivered_id = last_replayed_id
|
||||||
messages_delivered = 0
|
|
||||||
first_message_time = None
|
|
||||||
xread_count = 0
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
@@ -453,39 +287,9 @@ async def _stream_listener(
|
|||||||
while True:
|
while True:
|
||||||
# Block for up to 30 seconds waiting for new messages
|
# Block for up to 30 seconds waiting for new messages
|
||||||
# This allows periodic checking if task is still running
|
# This allows periodic checking if task is still running
|
||||||
xread_start = time.perf_counter()
|
|
||||||
xread_count += 1
|
|
||||||
messages = await redis.xread(
|
messages = await redis.xread(
|
||||||
{stream_key: current_id}, block=30000, count=100
|
{stream_key: current_id}, block=30000, count=100
|
||||||
)
|
)
|
||||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
|
||||||
|
|
||||||
if messages:
|
|
||||||
msg_count = sum(len(msgs) for _, msgs in messages)
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"xread_count": xread_count,
|
|
||||||
"n_messages": msg_count,
|
|
||||||
"duration_ms": xread_time,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
elif xread_time > 1000:
|
|
||||||
# Only log timeouts (30s blocking)
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"xread_count": xread_count,
|
|
||||||
"duration_ms": xread_time,
|
|
||||||
"reason": "timeout",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not messages:
|
if not messages:
|
||||||
# Timeout - check if task is still running
|
# Timeout - check if task is still running
|
||||||
@@ -522,30 +326,10 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
# Update last delivered ID on successful delivery
|
# Update last delivered ID on successful delivery
|
||||||
last_delivered_id = current_id
|
last_delivered_id = current_id
|
||||||
messages_delivered += 1
|
|
||||||
if first_message_time is None:
|
|
||||||
first_message_time = time.perf_counter()
|
|
||||||
elapsed = (first_message_time - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"chunk_type": type(chunk).__name__,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
f"Subscriber queue full for task {task_id}, "
|
||||||
extra={
|
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"timeout_s": QUEUE_PUT_TIMEOUT,
|
|
||||||
"reason": "queue_full",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
# Send overflow error with recovery info
|
# Send overflow error with recovery info
|
||||||
try:
|
try:
|
||||||
@@ -567,44 +351,15 @@ async def _stream_listener(
|
|||||||
|
|
||||||
# Stop listening on finish
|
# Stop listening on finish
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"messages_delivered": messages_delivered,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(f"Error processing stream message: {e}")
|
||||||
f"Error processing stream message: {e}",
|
|
||||||
extra={"json_fields": {**log_meta, "error": str(e)}},
|
|
||||||
)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
logger.debug(f"Stream listener cancelled for task {task_id}")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"messages_delivered": messages_delivered,
|
|
||||||
"reason": "cancelled",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
raise # Re-raise to propagate cancellation
|
raise # Re-raise to propagate cancellation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
logger.error(f"Stream listener error for task {task_id}: {e}")
|
||||||
logger.error(
|
|
||||||
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
|
|
||||||
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
|
||||||
)
|
|
||||||
# On error, send finish to unblock subscriber
|
# On error, send finish to unblock subscriber
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
@@ -613,24 +368,10 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Could not deliver finish event after error",
|
f"Could not deliver finish event for task {task_id} after error"
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Clean up listener task mapping on exit
|
# Clean up listener task mapping on exit
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
|
|
||||||
f"delivered={messages_delivered}, xread_count={xread_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"messages_delivered": messages_delivered,
|
|
||||||
"xread_count": xread_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
_listener_tasks.pop(queue_id, None)
|
_listener_tasks.pop(queue_id, None)
|
||||||
|
|
||||||
|
|
||||||
@@ -857,10 +598,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
|||||||
ResponseType,
|
ResponseType,
|
||||||
StreamError,
|
StreamError,
|
||||||
StreamFinish,
|
StreamFinish,
|
||||||
StreamFinishStep,
|
|
||||||
StreamHeartbeat,
|
StreamHeartbeat,
|
||||||
StreamStart,
|
StreamStart,
|
||||||
StreamStartStep,
|
|
||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamTextEnd,
|
StreamTextEnd,
|
||||||
StreamTextStart,
|
StreamTextStart,
|
||||||
@@ -874,8 +613,6 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
|||||||
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
||||||
ResponseType.START.value: StreamStart,
|
ResponseType.START.value: StreamStart,
|
||||||
ResponseType.FINISH.value: StreamFinish,
|
ResponseType.FINISH.value: StreamFinish,
|
||||||
ResponseType.START_STEP.value: StreamStartStep,
|
|
||||||
ResponseType.FINISH_STEP.value: StreamFinishStep,
|
|
||||||
ResponseType.TEXT_START.value: StreamTextStart,
|
ResponseType.TEXT_START.value: StreamTextStart,
|
||||||
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
||||||
ResponseType.TEXT_END.value: StreamTextEnd,
|
ResponseType.TEXT_END.value: StreamTextEnd,
|
||||||
|
|||||||
@@ -7,7 +7,15 @@ from typing import Any, NotRequired, TypedDict
|
|||||||
|
|
||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
from backend.data.graph import (
|
||||||
|
Graph,
|
||||||
|
Link,
|
||||||
|
Node,
|
||||||
|
create_graph,
|
||||||
|
get_graph,
|
||||||
|
get_graph_all_versions,
|
||||||
|
get_store_listed_graphs,
|
||||||
|
)
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .service import (
|
from .service import (
|
||||||
@@ -20,6 +28,8 @@ from .service import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||||
|
|
||||||
|
|
||||||
class ExecutionSummary(TypedDict):
|
class ExecutionSummary(TypedDict):
|
||||||
"""Summary of a single execution for quality assessment."""
|
"""Summary of a single execution for quality assessment."""
|
||||||
@@ -659,6 +669,45 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _reassign_node_ids(graph: Graph) -> None:
|
||||||
|
"""Reassign all node and link IDs to new UUIDs.
|
||||||
|
|
||||||
|
This is needed when creating a new version to avoid unique constraint violations.
|
||||||
|
"""
|
||||||
|
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||||
|
|
||||||
|
for node in graph.nodes:
|
||||||
|
node.id = id_map[node.id]
|
||||||
|
|
||||||
|
for link in graph.links:
|
||||||
|
link.id = str(uuid.uuid4())
|
||||||
|
if link.source_id in id_map:
|
||||||
|
link.source_id = id_map[link.source_id]
|
||||||
|
if link.sink_id in id_map:
|
||||||
|
link.sink_id = id_map[link.sink_id]
|
||||||
|
|
||||||
|
|
||||||
|
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
|
||||||
|
"""Populate user_id in AgentExecutorBlock nodes.
|
||||||
|
|
||||||
|
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
|
||||||
|
This function fills in the actual user_id so sub-agents run with correct permissions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_json: Agent JSON dict (modified in place)
|
||||||
|
user_id: User ID to set
|
||||||
|
"""
|
||||||
|
for node in agent_json.get("nodes", []):
|
||||||
|
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
|
||||||
|
input_default = node.get("input_default") or {}
|
||||||
|
if not input_default.get("user_id"):
|
||||||
|
input_default["user_id"] = user_id
|
||||||
|
node["input_default"] = input_default
|
||||||
|
logger.debug(
|
||||||
|
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def save_agent_to_library(
|
async def save_agent_to_library(
|
||||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||||
) -> tuple[Graph, Any]:
|
) -> tuple[Graph, Any]:
|
||||||
@@ -672,10 +721,35 @@ async def save_agent_to_library(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (created Graph, LibraryAgent)
|
Tuple of (created Graph, LibraryAgent)
|
||||||
"""
|
"""
|
||||||
|
# Populate user_id in AgentExecutorBlock nodes before conversion
|
||||||
|
_populate_agent_executor_user_ids(agent_json, user_id)
|
||||||
|
|
||||||
graph = json_to_graph(agent_json)
|
graph = json_to_graph(agent_json)
|
||||||
|
|
||||||
if is_update:
|
if is_update:
|
||||||
return await library_db.update_graph_in_library(graph, user_id)
|
if graph.id:
|
||||||
return await library_db.create_graph_in_library(graph, user_id)
|
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||||
|
if existing_versions:
|
||||||
|
latest_version = max(v.version for v in existing_versions)
|
||||||
|
graph.version = latest_version + 1
|
||||||
|
_reassign_node_ids(graph)
|
||||||
|
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||||
|
else:
|
||||||
|
graph.id = str(uuid.uuid4())
|
||||||
|
graph.version = 1
|
||||||
|
_reassign_node_ids(graph)
|
||||||
|
logger.info(f"Creating new agent with ID {graph.id}")
|
||||||
|
|
||||||
|
created_graph = await create_graph(graph, user_id)
|
||||||
|
|
||||||
|
library_agents = await library_db.create_library_agent(
|
||||||
|
graph=created_graph,
|
||||||
|
user_id=user_id,
|
||||||
|
sensitive_action_safe_mode=True,
|
||||||
|
create_library_agents_for_sub_graphs=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return created_graph, library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -206,9 +206,9 @@ async def search_agents(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
no_results_msg = (
|
no_results_msg = (
|
||||||
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
|
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
||||||
if source == "marketplace"
|
if source == "marketplace"
|
||||||
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
|
else f"No agents matching '{query}' found in your library."
|
||||||
)
|
)
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||||
@@ -224,10 +224,10 @@ async def search_agents(
|
|||||||
message = (
|
message = (
|
||||||
"Now you have found some options for the user to choose from. "
|
"Now you have found some options for the user to choose from. "
|
||||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||||
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
|
"Please ask the user if they would like to use any of these agents."
|
||||||
if source == "marketplace"
|
if source == "marketplace"
|
||||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
|
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
||||||
)
|
)
|
||||||
|
|
||||||
return AgentsFoundResponse(
|
return AgentsFoundResponse(
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||||
@@ -27,6 +29,23 @@ from .models import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomizeAgentInput(BaseModel):
|
||||||
|
"""Input parameters for the customize_agent tool."""
|
||||||
|
|
||||||
|
agent_id: str = ""
|
||||||
|
modifications: str = ""
|
||||||
|
context: str = ""
|
||||||
|
save: bool = True
|
||||||
|
|
||||||
|
@field_validator("agent_id", "modifications", "context", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_strings(cls, v: Any) -> str:
|
||||||
|
"""Strip whitespace from string fields."""
|
||||||
|
if isinstance(v, str):
|
||||||
|
return v.strip()
|
||||||
|
return v if v is not None else ""
|
||||||
|
|
||||||
|
|
||||||
class CustomizeAgentTool(BaseTool):
|
class CustomizeAgentTool(BaseTool):
|
||||||
"""Tool for customizing marketplace/template agents using natural language."""
|
"""Tool for customizing marketplace/template agents using natural language."""
|
||||||
|
|
||||||
@@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
"""Execute the customize_agent tool.
|
"""Execute the customize_agent tool.
|
||||||
|
|
||||||
@@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
3. Call customize_template with the modification request
|
3. Call customize_template with the modification request
|
||||||
4. Preview or save based on the save parameter
|
4. Preview or save based on the save parameter
|
||||||
"""
|
"""
|
||||||
agent_id = kwargs.get("agent_id", "").strip()
|
params = CustomizeAgentInput(**kwargs)
|
||||||
modifications = kwargs.get("modifications", "").strip()
|
|
||||||
context = kwargs.get("context", "")
|
|
||||||
save = kwargs.get("save", True)
|
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
if not agent_id:
|
if not params.agent_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||||
error="missing_agent_id",
|
error="missing_agent_id",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not modifications:
|
if not params.modifications:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please describe how you want to customize this agent.",
|
message="Please describe how you want to customize this agent.",
|
||||||
error="missing_modifications",
|
error="missing_modifications",
|
||||||
@@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Parse agent_id in format "creator/slug"
|
# Parse agent_id in format "creator/slug"
|
||||||
parts = [p.strip() for p in agent_id.split("/")]
|
parts = params.agent_id.split("/")
|
||||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
f"Invalid agent ID format: '{agent_id}'. "
|
f"Invalid agent ID format: '{params.agent_id}'. "
|
||||||
"Expected format is 'creator/agent-name' "
|
"Expected format is 'creator/agent-name' "
|
||||||
"(e.g., 'autogpt/newsletter-writer')."
|
"(e.g., 'autogpt/newsletter-writer')."
|
||||||
),
|
),
|
||||||
@@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
except AgentNotFoundError:
|
except AgentNotFoundError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
f"Could not find marketplace agent '{agent_id}'. "
|
f"Could not find marketplace agent '{params.agent_id}'. "
|
||||||
"Please check the agent ID and try again."
|
"Please check the agent ID and try again."
|
||||||
),
|
),
|
||||||
error="agent_not_found",
|
error="agent_not_found",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to fetch the marketplace agent. Please try again.",
|
message="Failed to fetch the marketplace agent. Please try again.",
|
||||||
error="fetch_error",
|
error="fetch_error",
|
||||||
@@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
if not agent_details.store_listing_version_id:
|
if not agent_details.store_listing_version_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
f"The agent '{agent_id}' does not have an available version. "
|
f"The agent '{params.agent_id}' does not have an available version. "
|
||||||
"Please try a different agent."
|
"Please try a different agent."
|
||||||
),
|
),
|
||||||
error="no_version_available",
|
error="no_version_available",
|
||||||
@@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||||
template_agent = graph_to_json(graph)
|
template_agent = graph_to_json(graph)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to fetch the agent configuration. Please try again.",
|
message="Failed to fetch the agent configuration. Please try again.",
|
||||||
error="graph_fetch_error",
|
error="graph_fetch_error",
|
||||||
@@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
try:
|
try:
|
||||||
result = await customize_template(
|
result = await customize_template(
|
||||||
template_agent=template_agent,
|
template_agent=template_agent,
|
||||||
modification_request=modifications,
|
modification_request=params.modifications,
|
||||||
context=context,
|
context=params.context,
|
||||||
)
|
)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
@@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
"Failed to customize the agent due to a service error. "
|
"Failed to customize the agent due to a service error. "
|
||||||
@@ -219,55 +235,25 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle error response
|
# Handle response using match/case for cleaner pattern matching
|
||||||
if isinstance(result, dict) and result.get("type") == "error":
|
return await self._handle_customization_result(
|
||||||
error_msg = result.get("error", "Unknown error")
|
result=result,
|
||||||
error_type = result.get("error_type", "unknown")
|
params=params,
|
||||||
user_message = get_user_message_for_error(
|
agent_details=agent_details,
|
||||||
error_type,
|
user_id=user_id,
|
||||||
operation="customize the agent",
|
session_id=session_id,
|
||||||
llm_parse_message=(
|
)
|
||||||
"The AI had trouble customizing the agent. "
|
|
||||||
"Please try again or simplify your request."
|
|
||||||
),
|
|
||||||
validation_message=(
|
|
||||||
"The customized agent failed validation. "
|
|
||||||
"Please try rephrasing your request."
|
|
||||||
),
|
|
||||||
error_details=error_msg,
|
|
||||||
)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=user_message,
|
|
||||||
error=f"customization_failed:{error_type}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle clarifying questions
|
async def _handle_customization_result(
|
||||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
self,
|
||||||
questions = result.get("questions") or []
|
result: dict[str, Any],
|
||||||
if not isinstance(questions, list):
|
params: CustomizeAgentInput,
|
||||||
logger.error(
|
agent_details: Any,
|
||||||
f"Unexpected clarifying questions format: {type(questions)}"
|
user_id: str | None,
|
||||||
)
|
session_id: str | None,
|
||||||
questions = []
|
) -> ToolResponseBase:
|
||||||
return ClarificationNeededResponse(
|
"""Handle the result from customize_template using pattern matching."""
|
||||||
message=(
|
# Ensure result is a dict
|
||||||
"I need some more information to customize this agent. "
|
|
||||||
"Please answer the following questions:"
|
|
||||||
),
|
|
||||||
questions=[
|
|
||||||
ClarifyingQuestion(
|
|
||||||
question=q.get("question", ""),
|
|
||||||
keyword=q.get("keyword", ""),
|
|
||||||
example=q.get("example"),
|
|
||||||
)
|
|
||||||
for q in questions
|
|
||||||
if isinstance(q, dict)
|
|
||||||
],
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Result should be the customized agent JSON
|
|
||||||
if not isinstance(result, dict):
|
if not isinstance(result, dict):
|
||||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
@@ -276,8 +262,77 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
customized_agent = result
|
result_type = result.get("type")
|
||||||
|
|
||||||
|
match result_type:
|
||||||
|
case "error":
|
||||||
|
error_msg = result.get("error", "Unknown error")
|
||||||
|
error_type = result.get("error_type", "unknown")
|
||||||
|
user_message = get_user_message_for_error(
|
||||||
|
error_type,
|
||||||
|
operation="customize the agent",
|
||||||
|
llm_parse_message=(
|
||||||
|
"The AI had trouble customizing the agent. "
|
||||||
|
"Please try again or simplify your request."
|
||||||
|
),
|
||||||
|
validation_message=(
|
||||||
|
"The customized agent failed validation. "
|
||||||
|
"Please try rephrasing your request."
|
||||||
|
),
|
||||||
|
error_details=error_msg,
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=user_message,
|
||||||
|
error=f"customization_failed:{error_type}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
case "clarifying_questions":
|
||||||
|
questions_data = result.get("questions") or []
|
||||||
|
if not isinstance(questions_data, list):
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected clarifying questions format: {type(questions_data)}"
|
||||||
|
)
|
||||||
|
questions_data = []
|
||||||
|
|
||||||
|
questions = [
|
||||||
|
ClarifyingQuestion(
|
||||||
|
question=q.get("question", "") if isinstance(q, dict) else "",
|
||||||
|
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
|
||||||
|
example=q.get("example") if isinstance(q, dict) else None,
|
||||||
|
)
|
||||||
|
for q in questions_data
|
||||||
|
if isinstance(q, dict)
|
||||||
|
]
|
||||||
|
|
||||||
|
return ClarificationNeededResponse(
|
||||||
|
message=(
|
||||||
|
"I need some more information to customize this agent. "
|
||||||
|
"Please answer the following questions:"
|
||||||
|
),
|
||||||
|
questions=questions,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
# Default case: result is the customized agent JSON
|
||||||
|
return await self._save_or_preview_agent(
|
||||||
|
customized_agent=result,
|
||||||
|
params=params,
|
||||||
|
agent_details=agent_details,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _save_or_preview_agent(
|
||||||
|
self,
|
||||||
|
customized_agent: dict[str, Any],
|
||||||
|
params: CustomizeAgentInput,
|
||||||
|
agent_details: Any,
|
||||||
|
user_id: str | None,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""Save or preview the customized agent based on params.save."""
|
||||||
agent_name = customized_agent.get(
|
agent_name = customized_agent.get(
|
||||||
"name", f"Customized {agent_details.agent_name}"
|
"name", f"Customized {agent_details.agent_name}"
|
||||||
)
|
)
|
||||||
@@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||||
link_count = len(links) if isinstance(links, list) else 0
|
link_count = len(links) if isinstance(links, list) else 0
|
||||||
|
|
||||||
if not save:
|
if not params.save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||||
|
|||||||
@@ -13,32 +13,10 @@ from backend.api.features.chat.tools.models import (
|
|||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||||
from backend.data.block import BlockType, get_block
|
from backend.data.block import get_block
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_TARGET_RESULTS = 10
|
|
||||||
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
|
|
||||||
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
|
|
||||||
_OVERFETCH_PAGE_SIZE = 40
|
|
||||||
|
|
||||||
# Block types that only work within graphs and cannot run standalone in CoPilot.
|
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES = {
|
|
||||||
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
|
|
||||||
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
|
|
||||||
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
|
|
||||||
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
|
|
||||||
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
|
||||||
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
|
||||||
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
|
||||||
}
|
|
||||||
|
|
||||||
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
|
||||||
COPILOT_EXCLUDED_BLOCK_IDS = {
|
|
||||||
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
|
|
||||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class FindBlockTool(BaseTool):
|
class FindBlockTool(BaseTool):
|
||||||
"""Tool for searching available blocks."""
|
"""Tool for searching available blocks."""
|
||||||
@@ -110,7 +88,7 @@ class FindBlockTool(BaseTool):
|
|||||||
query=query,
|
query=query,
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=1,
|
page=1,
|
||||||
page_size=_OVERFETCH_PAGE_SIZE,
|
page_size=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
@@ -130,90 +108,60 @@ class FindBlockTool(BaseTool):
|
|||||||
block = get_block(block_id)
|
block = get_block(block_id)
|
||||||
|
|
||||||
# Skip disabled blocks
|
# Skip disabled blocks
|
||||||
if not block or block.disabled:
|
if block and not block.disabled:
|
||||||
continue
|
# Get input/output schemas
|
||||||
|
input_schema = {}
|
||||||
|
output_schema = {}
|
||||||
|
try:
|
||||||
|
input_schema = block.input_schema.jsonschema()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
output_schema = block.output_schema.jsonschema()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Skip blocks excluded from CoPilot (graph-only blocks)
|
# Get categories from block instance
|
||||||
if (
|
categories = []
|
||||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
if hasattr(block, "categories") and block.categories:
|
||||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
categories = [cat.value for cat in block.categories]
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get input/output schemas
|
# Extract required inputs for easier use
|
||||||
input_schema = {}
|
required_inputs: list[BlockInputFieldInfo] = []
|
||||||
output_schema = {}
|
if input_schema:
|
||||||
try:
|
properties = input_schema.get("properties", {})
|
||||||
input_schema = block.input_schema.jsonschema()
|
required_fields = set(input_schema.get("required", []))
|
||||||
except Exception as e:
|
# Get credential field names to exclude from required inputs
|
||||||
logger.debug(
|
credentials_fields = set(
|
||||||
"Failed to generate input schema for block %s: %s",
|
block.input_schema.get_credentials_fields().keys()
|
||||||
block_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
output_schema = block.output_schema.jsonschema()
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(
|
|
||||||
"Failed to generate output schema for block %s: %s",
|
|
||||||
block_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get categories from block instance
|
|
||||||
categories = []
|
|
||||||
if hasattr(block, "categories") and block.categories:
|
|
||||||
categories = [cat.value for cat in block.categories]
|
|
||||||
|
|
||||||
# Extract required inputs for easier use
|
|
||||||
required_inputs: list[BlockInputFieldInfo] = []
|
|
||||||
if input_schema:
|
|
||||||
properties = input_schema.get("properties", {})
|
|
||||||
required_fields = set(input_schema.get("required", []))
|
|
||||||
# Get credential field names to exclude from required inputs
|
|
||||||
credentials_fields = set(
|
|
||||||
block.input_schema.get_credentials_fields().keys()
|
|
||||||
)
|
|
||||||
|
|
||||||
for field_name, field_schema in properties.items():
|
|
||||||
# Skip credential fields - they're handled separately
|
|
||||||
if field_name in credentials_fields:
|
|
||||||
continue
|
|
||||||
|
|
||||||
required_inputs.append(
|
|
||||||
BlockInputFieldInfo(
|
|
||||||
name=field_name,
|
|
||||||
type=field_schema.get("type", "string"),
|
|
||||||
description=field_schema.get("description", ""),
|
|
||||||
required=field_name in required_fields,
|
|
||||||
default=field_schema.get("default"),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
blocks.append(
|
for field_name, field_schema in properties.items():
|
||||||
BlockInfoSummary(
|
# Skip credential fields - they're handled separately
|
||||||
id=block_id,
|
if field_name in credentials_fields:
|
||||||
name=block.name,
|
continue
|
||||||
description=block.description or "",
|
|
||||||
categories=categories,
|
required_inputs.append(
|
||||||
input_schema=input_schema,
|
BlockInputFieldInfo(
|
||||||
output_schema=output_schema,
|
name=field_name,
|
||||||
required_inputs=required_inputs,
|
type=field_schema.get("type", "string"),
|
||||||
|
description=field_schema.get("description", ""),
|
||||||
|
required=field_name in required_fields,
|
||||||
|
default=field_schema.get("default"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
blocks.append(
|
||||||
|
BlockInfoSummary(
|
||||||
|
id=block_id,
|
||||||
|
name=block.name,
|
||||||
|
description=block.description or "",
|
||||||
|
categories=categories,
|
||||||
|
input_schema=input_schema,
|
||||||
|
output_schema=output_schema,
|
||||||
|
required_inputs=required_inputs,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if len(blocks) >= _TARGET_RESULTS:
|
|
||||||
break
|
|
||||||
|
|
||||||
if blocks and len(blocks) < _TARGET_RESULTS:
|
|
||||||
logger.debug(
|
|
||||||
"find_block returned %d/%d results for query '%s' "
|
|
||||||
"(filtered %d excluded/disabled blocks)",
|
|
||||||
len(blocks),
|
|
||||||
_TARGET_RESULTS,
|
|
||||||
query,
|
|
||||||
len(results) - len(blocks),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not blocks:
|
if not blocks:
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
|
|||||||
@@ -1,139 +0,0 @@
|
|||||||
"""Tests for block filtering in FindBlockTool."""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.api.features.chat.tools.find_block import (
|
|
||||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
|
||||||
FindBlockTool,
|
|
||||||
)
|
|
||||||
from backend.api.features.chat.tools.models import BlockListResponse
|
|
||||||
from backend.data.block import BlockType
|
|
||||||
|
|
||||||
from ._test_data import make_session
|
|
||||||
|
|
||||||
_TEST_USER_ID = "test-user-find-block"
|
|
||||||
|
|
||||||
|
|
||||||
def make_mock_block(
|
|
||||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
|
||||||
):
|
|
||||||
"""Create a mock block for testing."""
|
|
||||||
mock = MagicMock()
|
|
||||||
mock.id = block_id
|
|
||||||
mock.name = name
|
|
||||||
mock.description = f"{name} description"
|
|
||||||
mock.block_type = block_type
|
|
||||||
mock.disabled = disabled
|
|
||||||
mock.input_schema = MagicMock()
|
|
||||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
|
||||||
mock.input_schema.get_credentials_fields.return_value = {}
|
|
||||||
mock.output_schema = MagicMock()
|
|
||||||
mock.output_schema.jsonschema.return_value = {}
|
|
||||||
mock.categories = []
|
|
||||||
return mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestFindBlockFiltering:
|
|
||||||
"""Tests for block filtering in FindBlockTool."""
|
|
||||||
|
|
||||||
def test_excluded_block_types_contains_expected_types(self):
|
|
||||||
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
|
||||||
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
|
|
||||||
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
|
||||||
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
|
||||||
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_type_filtered_from_results(self):
|
|
||||||
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
|
||||||
search_results = [
|
|
||||||
{"content_id": "input-block-id", "score": 0.9},
|
|
||||||
{"content_id": "standard-block-id", "score": 0.8},
|
|
||||||
]
|
|
||||||
|
|
||||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
|
||||||
standard_block = make_mock_block(
|
|
||||||
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_get_block(block_id):
|
|
||||||
return {
|
|
||||||
"input-block-id": input_block,
|
|
||||||
"standard-block-id": standard_block,
|
|
||||||
}.get(block_id)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=(search_results, 2),
|
|
||||||
):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.get_block",
|
|
||||||
side_effect=mock_get_block,
|
|
||||||
):
|
|
||||||
tool = FindBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID, session=session, query="test"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should only return the standard block, not the INPUT block
|
|
||||||
assert isinstance(response, BlockListResponse)
|
|
||||||
assert len(response.blocks) == 1
|
|
||||||
assert response.blocks[0].id == "standard-block-id"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_id_filtered_from_results(self):
|
|
||||||
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
|
||||||
search_results = [
|
|
||||||
{"content_id": smart_decision_id, "score": 0.9},
|
|
||||||
{"content_id": "normal-block-id", "score": 0.8},
|
|
||||||
]
|
|
||||||
|
|
||||||
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
|
||||||
smart_block = make_mock_block(
|
|
||||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
normal_block = make_mock_block(
|
|
||||||
"normal-block-id", "Normal Block", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_get_block(block_id):
|
|
||||||
return {
|
|
||||||
smart_decision_id: smart_block,
|
|
||||||
"normal-block-id": normal_block,
|
|
||||||
}.get(block_id)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=(search_results, 2),
|
|
||||||
):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.get_block",
|
|
||||||
side_effect=mock_get_block,
|
|
||||||
):
|
|
||||||
tool = FindBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID, session=session, query="decision"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should only return normal block, not SmartDecisionMakerBlock
|
|
||||||
assert isinstance(response, BlockListResponse)
|
|
||||||
assert len(response.blocks) == 1
|
|
||||||
assert response.blocks[0].id == "normal-block-id"
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
"""Shared helpers for chat tools."""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def get_inputs_from_schema(
|
|
||||||
input_schema: dict[str, Any],
|
|
||||||
exclude_fields: set[str] | None = None,
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Extract input field info from JSON schema."""
|
|
||||||
if not isinstance(input_schema, dict):
|
|
||||||
return []
|
|
||||||
|
|
||||||
exclude = exclude_fields or set()
|
|
||||||
properties = input_schema.get("properties", {})
|
|
||||||
required = set(input_schema.get("required", []))
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"name": name,
|
|
||||||
"title": schema.get("title", name),
|
|
||||||
"type": schema.get("type", "string"),
|
|
||||||
"description": schema.get("description", ""),
|
|
||||||
"required": name in required,
|
|
||||||
"default": schema.get("default"),
|
|
||||||
}
|
|
||||||
for name, schema in properties.items()
|
|
||||||
if name not in exclude
|
|
||||||
]
|
|
||||||
@@ -24,7 +24,6 @@ from backend.util.timezone_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .helpers import get_inputs_from_schema
|
|
||||||
from .models import (
|
from .models import (
|
||||||
AgentDetails,
|
AgentDetails,
|
||||||
AgentDetailsResponse,
|
AgentDetailsResponse,
|
||||||
@@ -262,7 +261,7 @@ class RunAgentTool(BaseTool):
|
|||||||
),
|
),
|
||||||
requirements={
|
requirements={
|
||||||
"credentials": requirements_creds_list,
|
"credentials": requirements_creds_list,
|
||||||
"inputs": get_inputs_from_schema(graph.input_schema),
|
"inputs": self._get_inputs_list(graph.input_schema),
|
||||||
"execution_modes": self._get_execution_modes(graph),
|
"execution_modes": self._get_execution_modes(graph),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@@ -370,6 +369,22 @@ class RunAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
|
"""Extract inputs list from schema."""
|
||||||
|
inputs_list = []
|
||||||
|
if isinstance(input_schema, dict) and "properties" in input_schema:
|
||||||
|
for field_name, field_schema in input_schema["properties"].items():
|
||||||
|
inputs_list.append(
|
||||||
|
{
|
||||||
|
"name": field_name,
|
||||||
|
"title": field_schema.get("title", field_name),
|
||||||
|
"type": field_schema.get("type", "string"),
|
||||||
|
"description": field_schema.get("description", ""),
|
||||||
|
"required": field_name in input_schema.get("required", []),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return inputs_list
|
||||||
|
|
||||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||||
"""Get available execution modes for the graph."""
|
"""Get available execution modes for the graph."""
|
||||||
trigger_info = graph.trigger_setup_info
|
trigger_info = graph.trigger_setup_info
|
||||||
@@ -383,7 +398,7 @@ class RunAgentTool(BaseTool):
|
|||||||
suffix: str,
|
suffix: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a message describing available inputs for an agent."""
|
"""Build a message describing available inputs for an agent."""
|
||||||
inputs_list = get_inputs_from_schema(graph.input_schema)
|
inputs_list = self._get_inputs_list(graph.input_schema)
|
||||||
required_names = [i["name"] for i in inputs_list if i["required"]]
|
required_names = [i["name"] for i in inputs_list if i["required"]]
|
||||||
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
||||||
|
|
||||||
|
|||||||
@@ -8,19 +8,14 @@ from typing import Any
|
|||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.api.features.chat.tools.find_block import (
|
from backend.data.block import get_block
|
||||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
|
||||||
)
|
|
||||||
from backend.data.block import AnyBlockSchema, get_block
|
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .helpers import get_inputs_from_schema
|
|
||||||
from .models import (
|
from .models import (
|
||||||
BlockOutputResponse,
|
BlockOutputResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -29,10 +24,7 @@ from .models import (
|
|||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
UserReadiness,
|
UserReadiness,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import build_missing_credentials_from_field_info
|
||||||
build_missing_credentials_from_field_info,
|
|
||||||
match_credentials_to_requirements,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -81,6 +73,91 @@ class RunBlockTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def _check_block_credentials(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
block: Any,
|
||||||
|
input_data: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
|
"""
|
||||||
|
Check if user has required credentials for a block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
block: Block to check credentials for
|
||||||
|
input_data: Input data for the block (used to determine provider via discriminator)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[matched_credentials, missing_credentials]
|
||||||
|
"""
|
||||||
|
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||||
|
missing_credentials: list[CredentialsMetaInput] = []
|
||||||
|
input_data = input_data or {}
|
||||||
|
|
||||||
|
# Get credential field info from block's input schema
|
||||||
|
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||||
|
|
||||||
|
if not credentials_fields_info:
|
||||||
|
return matched_credentials, missing_credentials
|
||||||
|
|
||||||
|
# Get user's available credentials
|
||||||
|
creds_manager = IntegrationCredentialsManager()
|
||||||
|
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||||
|
|
||||||
|
for field_name, field_info in credentials_fields_info.items():
|
||||||
|
effective_field_info = field_info
|
||||||
|
if field_info.discriminator and field_info.discriminator_mapping:
|
||||||
|
# Get discriminator from input, falling back to schema default
|
||||||
|
discriminator_value = input_data.get(field_info.discriminator)
|
||||||
|
if discriminator_value is None:
|
||||||
|
field = block.input_schema.model_fields.get(
|
||||||
|
field_info.discriminator
|
||||||
|
)
|
||||||
|
if field and field.default is not PydanticUndefined:
|
||||||
|
discriminator_value = field.default
|
||||||
|
|
||||||
|
if (
|
||||||
|
discriminator_value
|
||||||
|
and discriminator_value in field_info.discriminator_mapping
|
||||||
|
):
|
||||||
|
effective_field_info = field_info.discriminate(discriminator_value)
|
||||||
|
logger.debug(
|
||||||
|
f"Discriminated provider for {field_name}: "
|
||||||
|
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
matching_cred = next(
|
||||||
|
(
|
||||||
|
cred
|
||||||
|
for cred in available_creds
|
||||||
|
if cred.provider in effective_field_info.provider
|
||||||
|
and cred.type in effective_field_info.supported_types
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if matching_cred:
|
||||||
|
matched_credentials[field_name] = CredentialsMetaInput(
|
||||||
|
id=matching_cred.id,
|
||||||
|
provider=matching_cred.provider, # type: ignore
|
||||||
|
type=matching_cred.type,
|
||||||
|
title=matching_cred.title,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create a placeholder for the missing credential
|
||||||
|
provider = next(iter(effective_field_info.provider), "unknown")
|
||||||
|
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
||||||
|
missing_credentials.append(
|
||||||
|
CredentialsMetaInput(
|
||||||
|
id=field_name,
|
||||||
|
provider=provider, # type: ignore
|
||||||
|
type=cred_type, # type: ignore
|
||||||
|
title=field_name.replace("_", " ").title(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return matched_credentials, missing_credentials
|
||||||
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -135,24 +212,11 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if block is excluded from CoPilot (graph-only blocks)
|
|
||||||
if (
|
|
||||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
|
||||||
):
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
|
||||||
"This block is designed for use within graphs only."
|
|
||||||
),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||||
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
creds_manager = IntegrationCredentialsManager()
|
||||||
matched_credentials, missing_credentials = (
|
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||||
await self._resolve_block_credentials(user_id, block, input_data)
|
user_id, block, input_data
|
||||||
)
|
)
|
||||||
|
|
||||||
if missing_credentials:
|
if missing_credentials:
|
||||||
@@ -281,75 +345,29 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _resolve_block_credentials(
|
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
block: AnyBlockSchema,
|
|
||||||
input_data: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Resolve credentials for a block by matching user's available credentials.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
block: Block to resolve credentials for
|
|
||||||
input_data: Input data for the block (used to determine provider via discriminator)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple of (matched_credentials, missing_credentials) - matched credentials
|
|
||||||
are used for block execution, missing ones indicate setup requirements.
|
|
||||||
"""
|
|
||||||
input_data = input_data or {}
|
|
||||||
requirements = self._resolve_discriminated_credentials(block, input_data)
|
|
||||||
|
|
||||||
if not requirements:
|
|
||||||
return {}, []
|
|
||||||
|
|
||||||
return await match_credentials_to_requirements(user_id, requirements)
|
|
||||||
|
|
||||||
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
|
||||||
"""Extract non-credential inputs from block schema."""
|
"""Extract non-credential inputs from block schema."""
|
||||||
|
inputs_list = []
|
||||||
schema = block.input_schema.jsonschema()
|
schema = block.input_schema.jsonschema()
|
||||||
|
properties = schema.get("properties", {})
|
||||||
|
required_fields = set(schema.get("required", []))
|
||||||
|
|
||||||
|
# Get credential field names to exclude
|
||||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||||
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
|
||||||
|
|
||||||
def _resolve_discriminated_credentials(
|
for field_name, field_schema in properties.items():
|
||||||
self,
|
# Skip credential fields
|
||||||
block: AnyBlockSchema,
|
if field_name in credentials_fields:
|
||||||
input_data: dict[str, Any],
|
continue
|
||||||
) -> dict[str, CredentialsFieldInfo]:
|
|
||||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
|
||||||
if not credentials_fields_info:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
inputs_list.append(
|
||||||
|
{
|
||||||
|
"name": field_name,
|
||||||
|
"title": field_schema.get("title", field_name),
|
||||||
|
"type": field_schema.get("type", "string"),
|
||||||
|
"description": field_schema.get("description", ""),
|
||||||
|
"required": field_name in required_fields,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
return inputs_list
|
||||||
effective_field_info = field_info
|
|
||||||
|
|
||||||
if field_info.discriminator and field_info.discriminator_mapping:
|
|
||||||
discriminator_value = input_data.get(field_info.discriminator)
|
|
||||||
if discriminator_value is None:
|
|
||||||
field = block.input_schema.model_fields.get(
|
|
||||||
field_info.discriminator
|
|
||||||
)
|
|
||||||
if field and field.default is not PydanticUndefined:
|
|
||||||
discriminator_value = field.default
|
|
||||||
|
|
||||||
if (
|
|
||||||
discriminator_value
|
|
||||||
and discriminator_value in field_info.discriminator_mapping
|
|
||||||
):
|
|
||||||
effective_field_info = field_info.discriminate(discriminator_value)
|
|
||||||
# For host-scoped credentials, add the discriminator value
|
|
||||||
# (e.g., URL) so _credential_is_for_host can match it
|
|
||||||
effective_field_info.discriminator_values.add(discriminator_value)
|
|
||||||
logger.debug(
|
|
||||||
f"Discriminated provider for {field_name}: "
|
|
||||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved[field_name] = effective_field_info
|
|
||||||
|
|
||||||
return resolved
|
|
||||||
|
|||||||
@@ -1,106 +0,0 @@
|
|||||||
"""Tests for block execution guards in RunBlockTool."""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.api.features.chat.tools.models import ErrorResponse
|
|
||||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
|
||||||
from backend.data.block import BlockType
|
|
||||||
|
|
||||||
from ._test_data import make_session
|
|
||||||
|
|
||||||
_TEST_USER_ID = "test-user-run-block"
|
|
||||||
|
|
||||||
|
|
||||||
def make_mock_block(
|
|
||||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
|
||||||
):
|
|
||||||
"""Create a mock block for testing."""
|
|
||||||
mock = MagicMock()
|
|
||||||
mock.id = block_id
|
|
||||||
mock.name = name
|
|
||||||
mock.block_type = block_type
|
|
||||||
mock.disabled = disabled
|
|
||||||
mock.input_schema = MagicMock()
|
|
||||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
|
||||||
mock.input_schema.get_credentials_fields_info.return_value = []
|
|
||||||
return mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunBlockFiltering:
|
|
||||||
"""Tests for block execution guards in RunBlockTool."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_type_returns_error(self):
|
|
||||||
"""Attempting to execute a block with excluded BlockType returns error."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=input_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id="input-block-id",
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, ErrorResponse)
|
|
||||||
assert "cannot be run directly in CoPilot" in response.message
|
|
||||||
assert "designed for use within graphs only" in response.message
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_id_returns_error(self):
|
|
||||||
"""Attempting to execute SmartDecisionMakerBlock returns error."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
|
||||||
smart_block = make_mock_block(
|
|
||||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=smart_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id=smart_decision_id,
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, ErrorResponse)
|
|
||||||
assert "cannot be run directly in CoPilot" in response.message
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_non_excluded_block_passes_guard(self):
|
|
||||||
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
standard_block = make_mock_block(
|
|
||||||
"standard-id", "HTTP Request", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=standard_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id="standard-id",
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should NOT be an ErrorResponse about CoPilot exclusion
|
|
||||||
# (may be other errors like missing credentials, but not the exclusion guard)
|
|
||||||
if isinstance(response, ErrorResponse):
|
|
||||||
assert "cannot be run directly in CoPilot" not in response.message
|
|
||||||
@@ -6,14 +6,9 @@ from typing import Any
|
|||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.library import model as library_model
|
from backend.api.features.library import model as library_model
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
|
from backend.data import graph as graph_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import (
|
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
|
||||||
Credentials,
|
|
||||||
CredentialsFieldInfo,
|
|
||||||
CredentialsMetaInput,
|
|
||||||
HostScopedCredentials,
|
|
||||||
OAuth2Credentials,
|
|
||||||
)
|
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
@@ -44,8 +39,14 @@ async def fetch_graph_from_store_slug(
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Get the graph from store listing version
|
# Get the graph from store listing version
|
||||||
graph = await store_db.get_available_graph(
|
graph_meta = await store_db.get_available_graph(
|
||||||
store_agent.store_listing_version_id, hide_nodes=False
|
store_agent.store_listing_version_id
|
||||||
|
)
|
||||||
|
graph = await graph_db.get_graph(
|
||||||
|
graph_id=graph_meta.id,
|
||||||
|
version=graph_meta.version,
|
||||||
|
user_id=None, # Public access
|
||||||
|
include_subgraphs=True,
|
||||||
)
|
)
|
||||||
return graph, store_agent
|
return graph, store_agent
|
||||||
|
|
||||||
@@ -122,7 +123,7 @@ def build_missing_credentials_from_graph(
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
field_key: _serialize_missing_credential(field_key, field_info)
|
field_key: _serialize_missing_credential(field_key, field_info)
|
||||||
for field_key, (field_info, _, _) in aggregated_fields.items()
|
for field_key, (field_info, _node_fields) in aggregated_fields.items()
|
||||||
if field_key not in matched_keys
|
if field_key not in matched_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,99 +225,6 @@ async def get_or_create_library_agent(
|
|||||||
return library_agents[0]
|
return library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
async def match_credentials_to_requirements(
|
|
||||||
user_id: str,
|
|
||||||
requirements: dict[str, CredentialsFieldInfo],
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Match user's credentials against a dictionary of credential requirements.
|
|
||||||
|
|
||||||
This is the core matching logic shared by both graph and block credential matching.
|
|
||||||
"""
|
|
||||||
matched: dict[str, CredentialsMetaInput] = {}
|
|
||||||
missing: list[CredentialsMetaInput] = []
|
|
||||||
|
|
||||||
if not requirements:
|
|
||||||
return matched, missing
|
|
||||||
|
|
||||||
available_creds = await get_user_credentials(user_id)
|
|
||||||
|
|
||||||
for field_name, field_info in requirements.items():
|
|
||||||
matching_cred = find_matching_credential(available_creds, field_info)
|
|
||||||
|
|
||||||
if matching_cred:
|
|
||||||
try:
|
|
||||||
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
|
||||||
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
|
||||||
f"credential_id={matching_cred.id}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
provider = next(iter(field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
|
||||||
missing.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=f"{field_name} (validation failed: {e})",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
provider = next(iter(field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
|
||||||
missing.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=field_name.replace("_", " ").title(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matched, missing
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_credentials(user_id: str) -> list[Credentials]:
|
|
||||||
"""Get all available credentials for a user."""
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
return await creds_manager.store.get_all_creds(user_id)
|
|
||||||
|
|
||||||
|
|
||||||
def find_matching_credential(
|
|
||||||
available_creds: list[Credentials],
|
|
||||||
field_info: CredentialsFieldInfo,
|
|
||||||
) -> Credentials | None:
|
|
||||||
"""Find a credential that matches the required provider, type, scopes, and host."""
|
|
||||||
for cred in available_creds:
|
|
||||||
if cred.provider not in field_info.provider:
|
|
||||||
continue
|
|
||||||
if cred.type not in field_info.supported_types:
|
|
||||||
continue
|
|
||||||
if cred.type == "oauth2" and not _credential_has_required_scopes(
|
|
||||||
cred, field_info
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
|
|
||||||
continue
|
|
||||||
return cred
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_credential_meta_from_match(
|
|
||||||
matching_cred: Credentials,
|
|
||||||
) -> CredentialsMetaInput:
|
|
||||||
"""Create a CredentialsMetaInput from a matched credential."""
|
|
||||||
return CredentialsMetaInput(
|
|
||||||
id=matching_cred.id,
|
|
||||||
provider=matching_cred.provider, # type: ignore
|
|
||||||
type=matching_cred.type,
|
|
||||||
title=matching_cred.title,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def match_user_credentials_to_graph(
|
async def match_user_credentials_to_graph(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
graph: GraphModel,
|
graph: GraphModel,
|
||||||
@@ -356,8 +264,7 @@ async def match_user_credentials_to_graph(
|
|||||||
# provider is in the set of acceptable providers.
|
# provider is in the set of acceptable providers.
|
||||||
for credential_field_name, (
|
for credential_field_name, (
|
||||||
credential_requirements,
|
credential_requirements,
|
||||||
_,
|
_node_fields,
|
||||||
_,
|
|
||||||
) in aggregated_creds.items():
|
) in aggregated_creds.items():
|
||||||
# Find first matching credential by provider, type, and scopes
|
# Find first matching credential by provider, type, and scopes
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
@@ -366,14 +273,7 @@ async def match_user_credentials_to_graph(
|
|||||||
for cred in available_creds
|
for cred in available_creds
|
||||||
if cred.provider in credential_requirements.provider
|
if cred.provider in credential_requirements.provider
|
||||||
and cred.type in credential_requirements.supported_types
|
and cred.type in credential_requirements.supported_types
|
||||||
and (
|
and _credential_has_required_scopes(cred, credential_requirements)
|
||||||
cred.type != "oauth2"
|
|
||||||
or _credential_has_required_scopes(cred, credential_requirements)
|
|
||||||
)
|
|
||||||
and (
|
|
||||||
cred.type != "host_scoped"
|
|
||||||
or _credential_is_for_host(cred, credential_requirements)
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -418,32 +318,27 @@ async def match_user_credentials_to_graph(
|
|||||||
|
|
||||||
|
|
||||||
def _credential_has_required_scopes(
|
def _credential_has_required_scopes(
|
||||||
credential: OAuth2Credentials,
|
credential: Credentials,
|
||||||
requirements: CredentialsFieldInfo,
|
requirements: CredentialsFieldInfo,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if an OAuth2 credential has all the scopes required by the input."""
|
"""
|
||||||
|
Check if a credential has all the scopes required by the block.
|
||||||
|
|
||||||
|
For OAuth2 credentials, verifies that the credential's scopes are a superset
|
||||||
|
of the required scopes. For other credential types, returns True (no scope check).
|
||||||
|
"""
|
||||||
|
# Only OAuth2 credentials have scopes to check
|
||||||
|
if credential.type != "oauth2":
|
||||||
|
return True
|
||||||
|
|
||||||
# If no scopes are required, any credential matches
|
# If no scopes are required, any credential matches
|
||||||
if not requirements.required_scopes:
|
if not requirements.required_scopes:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# Check that credential scopes are a superset of required scopes
|
||||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||||
|
|
||||||
|
|
||||||
def _credential_is_for_host(
|
|
||||||
credential: HostScopedCredentials,
|
|
||||||
requirements: CredentialsFieldInfo,
|
|
||||||
) -> bool:
|
|
||||||
"""Check if a host-scoped credential matches the host required by the input."""
|
|
||||||
# We need to know the host to match host-scoped credentials to.
|
|
||||||
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
|
|
||||||
# to discriminator_values. No discriminator_values -> no host to match against.
|
|
||||||
if not requirements.discriminator_values:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check that credential host matches required host.
|
|
||||||
# Host-scoped credential inputs are grouped by host, so any item from the set works.
|
|
||||||
return credential.matches_url(list(requirements.discriminator_values)[0])
|
|
||||||
|
|
||||||
|
|
||||||
async def check_user_has_required_credentials(
|
async def check_user_has_required_credentials(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
required_credentials: list[CredentialsMetaInput],
|
required_credentials: list[CredentialsMetaInput],
|
||||||
|
|||||||
@@ -19,10 +19,7 @@ from backend.data.graph import GraphSettings
|
|||||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||||
on_graph_activate,
|
|
||||||
on_graph_deactivate,
|
|
||||||
)
|
|
||||||
from backend.util.clients import get_scheduler_client
|
from backend.util.clients import get_scheduler_client
|
||||||
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
@@ -374,7 +371,7 @@ async def get_library_agent_by_graph_id(
|
|||||||
|
|
||||||
|
|
||||||
async def add_generated_agent_image(
|
async def add_generated_agent_image(
|
||||||
graph: graph_db.GraphBaseMeta,
|
graph: graph_db.BaseGraph,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
) -> Optional[prisma.models.LibraryAgent]:
|
) -> Optional[prisma.models.LibraryAgent]:
|
||||||
@@ -540,92 +537,6 @@ async def update_agent_version_in_library(
|
|||||||
return library_model.LibraryAgent.from_db(lib)
|
return library_model.LibraryAgent.from_db(lib)
|
||||||
|
|
||||||
|
|
||||||
async def create_graph_in_library(
|
|
||||||
graph: graph_db.Graph,
|
|
||||||
user_id: str,
|
|
||||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
|
||||||
"""Create a new graph and add it to the user's library."""
|
|
||||||
graph.version = 1
|
|
||||||
graph_model = graph_db.make_graph_model(graph, user_id)
|
|
||||||
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
|
||||||
|
|
||||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
|
||||||
|
|
||||||
library_agents = await create_library_agent(
|
|
||||||
graph=created_graph,
|
|
||||||
user_id=user_id,
|
|
||||||
sensitive_action_safe_mode=True,
|
|
||||||
create_library_agents_for_sub_graphs=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if created_graph.is_active:
|
|
||||||
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
|
||||||
|
|
||||||
return created_graph, library_agents[0]
|
|
||||||
|
|
||||||
|
|
||||||
async def update_graph_in_library(
|
|
||||||
graph: graph_db.Graph,
|
|
||||||
user_id: str,
|
|
||||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
|
||||||
"""Create a new version of an existing graph and update the library entry."""
|
|
||||||
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
|
|
||||||
current_active_version = (
|
|
||||||
next((v for v in existing_versions if v.is_active), None)
|
|
||||||
if existing_versions
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
graph.version = (
|
|
||||||
max(v.version for v in existing_versions) + 1 if existing_versions else 1
|
|
||||||
)
|
|
||||||
|
|
||||||
graph_model = graph_db.make_graph_model(graph, user_id)
|
|
||||||
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
|
||||||
|
|
||||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
|
||||||
|
|
||||||
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
|
||||||
if not library_agent:
|
|
||||||
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
|
||||||
|
|
||||||
library_agent = await update_library_agent_version_and_settings(
|
|
||||||
user_id, created_graph
|
|
||||||
)
|
|
||||||
|
|
||||||
if created_graph.is_active:
|
|
||||||
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
|
||||||
await graph_db.set_graph_active_version(
|
|
||||||
graph_id=created_graph.id,
|
|
||||||
version=created_graph.version,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
if current_active_version:
|
|
||||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
|
||||||
|
|
||||||
return created_graph, library_agent
|
|
||||||
|
|
||||||
|
|
||||||
async def update_library_agent_version_and_settings(
|
|
||||||
user_id: str, agent_graph: graph_db.GraphModel
|
|
||||||
) -> library_model.LibraryAgent:
|
|
||||||
"""Update library agent to point to new graph version and sync settings."""
|
|
||||||
library = await update_agent_version_in_library(
|
|
||||||
user_id, agent_graph.id, agent_graph.version
|
|
||||||
)
|
|
||||||
updated_settings = GraphSettings.from_graph(
|
|
||||||
graph=agent_graph,
|
|
||||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
|
||||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
|
||||||
)
|
|
||||||
if updated_settings != library.settings:
|
|
||||||
library = await update_library_agent(
|
|
||||||
library_agent_id=library.id,
|
|
||||||
user_id=user_id,
|
|
||||||
settings=updated_settings,
|
|
||||||
)
|
|
||||||
return library
|
|
||||||
|
|
||||||
|
|
||||||
async def update_library_agent(
|
async def update_library_agent(
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Literal, overload
|
from typing import Any, Literal
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
@@ -11,8 +11,8 @@ import prisma.types
|
|||||||
|
|
||||||
from backend.data.db import transaction
|
from backend.data.db import transaction
|
||||||
from backend.data.graph import (
|
from backend.data.graph import (
|
||||||
|
GraphMeta,
|
||||||
GraphModel,
|
GraphModel,
|
||||||
GraphModelWithoutNodes,
|
|
||||||
get_graph,
|
get_graph,
|
||||||
get_graph_as_admin,
|
get_graph_as_admin,
|
||||||
get_sub_graphs,
|
get_sub_graphs,
|
||||||
@@ -334,22 +334,7 @@ async def get_store_agent_details(
|
|||||||
raise DatabaseError("Failed to fetch agent details") from e
|
raise DatabaseError("Failed to fetch agent details") from e
|
||||||
|
|
||||||
|
|
||||||
@overload
|
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||||
async def get_available_graph(
|
|
||||||
store_listing_version_id: str, hide_nodes: Literal[False]
|
|
||||||
) -> GraphModel: ...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
|
||||||
async def get_available_graph(
|
|
||||||
store_listing_version_id: str, hide_nodes: Literal[True] = True
|
|
||||||
) -> GraphModelWithoutNodes: ...
|
|
||||||
|
|
||||||
|
|
||||||
async def get_available_graph(
|
|
||||||
store_listing_version_id: str,
|
|
||||||
hide_nodes: bool = True,
|
|
||||||
) -> GraphModelWithoutNodes | GraphModel:
|
|
||||||
try:
|
try:
|
||||||
# Get avaialble, non-deleted store listing version
|
# Get avaialble, non-deleted store listing version
|
||||||
store_listing_version = (
|
store_listing_version = (
|
||||||
@@ -359,7 +344,7 @@ async def get_available_graph(
|
|||||||
"isAvailable": True,
|
"isAvailable": True,
|
||||||
"isDeleted": False,
|
"isDeleted": False,
|
||||||
},
|
},
|
||||||
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
|
include={"AgentGraph": {"include": {"Nodes": True}}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -369,9 +354,7 @@ async def get_available_graph(
|
|||||||
detail=f"Store listing version {store_listing_version_id} not found",
|
detail=f"Store listing version {store_listing_version_id} not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
|
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
|
||||||
store_listing_version.AgentGraph
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting agent: {e}")
|
logger.error(f"Error getting agent: {e}")
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ Includes BM25 reranking for improved lexical relevance.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@@ -363,11 +362,7 @@ async def unified_hybrid_search(
|
|||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
results = await query_raw_with_schema(sql_query, *params)
|
||||||
results = await query_raw_with_schema(sql_query, *params)
|
|
||||||
except Exception as e:
|
|
||||||
await _log_vector_error_diagnostics(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
total = results[0]["total_count"] if results else 0
|
total = results[0]["total_count"] if results else 0
|
||||||
# Apply BM25 reranking
|
# Apply BM25 reranking
|
||||||
@@ -691,11 +686,7 @@ async def hybrid_search(
|
|||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
results = await query_raw_with_schema(sql_query, *params)
|
||||||
results = await query_raw_with_schema(sql_query, *params)
|
|
||||||
except Exception as e:
|
|
||||||
await _log_vector_error_diagnostics(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
total = results[0]["total_count"] if results else 0
|
total = results[0]["total_count"] if results else 0
|
||||||
|
|
||||||
@@ -727,87 +718,6 @@ async def hybrid_search_simple(
|
|||||||
return await hybrid_search(query=query, page=page, page_size=page_size)
|
return await hybrid_search(query=query, page=page, page_size=page_size)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Diagnostics
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
# Rate limit: only log vector error diagnostics once per this interval
|
|
||||||
_VECTOR_DIAG_INTERVAL_SECONDS = 60
|
|
||||||
_last_vector_diag_time: float = 0
|
|
||||||
|
|
||||||
|
|
||||||
async def _log_vector_error_diagnostics(error: Exception) -> None:
|
|
||||||
"""Log diagnostic info when 'type vector does not exist' error occurs.
|
|
||||||
|
|
||||||
Note: Diagnostic queries use query_raw_with_schema which may run on a different
|
|
||||||
pooled connection than the one that failed. Session-level search_path can differ,
|
|
||||||
so these diagnostics show cluster-wide state, not necessarily the failed session.
|
|
||||||
|
|
||||||
Includes rate limiting to avoid log spam - only logs once per minute.
|
|
||||||
Caller should re-raise the error after calling this function.
|
|
||||||
"""
|
|
||||||
global _last_vector_diag_time
|
|
||||||
|
|
||||||
# Check if this is the vector type error
|
|
||||||
error_str = str(error).lower()
|
|
||||||
if not (
|
|
||||||
"type" in error_str and "vector" in error_str and "does not exist" in error_str
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Rate limit: only log once per interval
|
|
||||||
now = time.time()
|
|
||||||
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
|
|
||||||
return
|
|
||||||
_last_vector_diag_time = now
|
|
||||||
|
|
||||||
try:
|
|
||||||
diagnostics: dict[str, object] = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
search_path_result = await query_raw_with_schema("SHOW search_path")
|
|
||||||
diagnostics["search_path"] = search_path_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["search_path"] = f"Error: {e}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
schema_result = await query_raw_with_schema("SELECT current_schema()")
|
|
||||||
diagnostics["current_schema"] = schema_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["current_schema"] = f"Error: {e}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
user_result = await query_raw_with_schema(
|
|
||||||
"SELECT current_user, session_user, current_database()"
|
|
||||||
)
|
|
||||||
diagnostics["user_info"] = user_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["user_info"] = f"Error: {e}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check pgvector extension installation (cluster-wide, stable info)
|
|
||||||
ext_result = await query_raw_with_schema(
|
|
||||||
"SELECT extname, extversion, nspname as schema "
|
|
||||||
"FROM pg_extension e "
|
|
||||||
"JOIN pg_namespace n ON e.extnamespace = n.oid "
|
|
||||||
"WHERE extname = 'vector'"
|
|
||||||
)
|
|
||||||
diagnostics["pgvector_extension"] = ext_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["pgvector_extension"] = f"Error: {e}"
|
|
||||||
|
|
||||||
logger.error(
|
|
||||||
f"Vector type error diagnostics:\n"
|
|
||||||
f" Error: {error}\n"
|
|
||||||
f" search_path: {diagnostics.get('search_path')}\n"
|
|
||||||
f" current_schema: {diagnostics.get('current_schema')}\n"
|
|
||||||
f" user_info: {diagnostics.get('user_info')}\n"
|
|
||||||
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
|
|
||||||
)
|
|
||||||
except Exception as diag_error:
|
|
||||||
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
|
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
||||||
# for existing code that expects the popularity parameter
|
# for existing code that expects the popularity parameter
|
||||||
HybridSearchWeights = StoreAgentSearchWeights
|
HybridSearchWeights = StoreAgentSearchWeights
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
|
|||||||
StyleType,
|
StyleType,
|
||||||
UpscaleOption,
|
UpscaleOption,
|
||||||
)
|
)
|
||||||
from backend.data.graph import GraphBaseMeta
|
from backend.data.graph import BaseGraph
|
||||||
from backend.data.model import CredentialsMetaInput, ProviderName
|
from backend.data.model import CredentialsMetaInput, ProviderName
|
||||||
from backend.integrations.credentials_store import ideogram_credentials
|
from backend.integrations.credentials_store import ideogram_credentials
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
|
|||||||
DIGITAL_ART = "digital art"
|
DIGITAL_ART = "digital art"
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||||
if settings.config.use_agent_image_generation_v2:
|
if settings.config.use_agent_image_generation_v2:
|
||||||
return await generate_agent_image_v2(graph=agent)
|
return await generate_agent_image_v2(graph=agent)
|
||||||
else:
|
else:
|
||||||
return await generate_agent_image_v1(agent=agent)
|
return await generate_agent_image_v1(agent=agent)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Generate an image for an agent using Ideogram model.
|
Generate an image for an agent using Ideogram model.
|
||||||
Returns:
|
Returns:
|
||||||
@@ -54,17 +54,14 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
|
|||||||
description = f"{name} ({graph.description})" if graph.description else name
|
description = f"{name} ({graph.description})" if graph.description else name
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
"Create a visually striking retro-futuristic vector pop art illustration "
|
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
|
||||||
f'prominently featuring "{name}" in bold typography. The image clearly and '
|
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
|
||||||
f"literally depicts a {description}, along with recognizable objects directly "
|
f"along with recognizable objects directly associated with the primary function of a {name}. "
|
||||||
f"associated with the primary function of a {name}. "
|
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
|
||||||
f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
|
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
|
||||||
f"clearly conveying the purpose of a {name}. "
|
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
|
||||||
"Maintain vibrant, limited-palette colors, sharp vector lines, "
|
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
|
||||||
"geometric shapes, flat illustration techniques, and solid colors "
|
f"prioritizing clear visual storytelling and thematic clarity above all else."
|
||||||
"without gradients or shading. Preserve a retro-futuristic aesthetic "
|
|
||||||
"influenced by mid-century futurism and 1960s psychedelia, "
|
|
||||||
"prioritizing clear visual storytelling and thematic clarity above all else."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
custom_colors = [
|
custom_colors = [
|
||||||
@@ -102,12 +99,12 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
|
|||||||
return io.BytesIO(response.content)
|
return io.BytesIO(response.content)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Generate an image for an agent using Flux model via Replicate API.
|
Generate an image for an agent using Flux model via Replicate API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent (GraphBaseMeta | AgentGraph): The agent to generate an image for
|
agent (Graph): The agent to generate an image for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
io.BytesIO: The generated image as bytes
|
io.BytesIO: The generated image as bytes
|
||||||
@@ -117,13 +114,7 @@ async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.Bytes
|
|||||||
raise ValueError("Missing Replicate API key in settings")
|
raise ValueError("Missing Replicate API key in settings")
|
||||||
|
|
||||||
# Construct prompt from agent details
|
# Construct prompt from agent details
|
||||||
prompt = (
|
prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
|
||||||
"Create a visually engaging app store thumbnail for the AI agent "
|
|
||||||
"that highlights what it does in a clear and captivating way:\n"
|
|
||||||
f"- **Name**: {agent.name}\n"
|
|
||||||
f"- **Description**: {agent.description}\n"
|
|
||||||
f"Focus on showcasing its core functionality with an appealing design."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set up Replicate client
|
# Set up Replicate client
|
||||||
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ async def get_agent(
|
|||||||
)
|
)
|
||||||
async def get_graph_meta_by_store_listing_version_id(
|
async def get_graph_meta_by_store_listing_version_id(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
) -> backend.data.graph.GraphMeta:
|
||||||
"""
|
"""
|
||||||
Get Agent Graph from Store Listing Version ID.
|
Get Agent Graph from Store Listing Version ID.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ from backend.util.timezone_utils import (
|
|||||||
from backend.util.virus_scanner import scan_content_safe
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
|
||||||
from .library import db as library_db
|
from .library import db as library_db
|
||||||
|
from .library import model as library_model
|
||||||
from .store.model import StoreAgentDetails
|
from .store.model import StoreAgentDetails
|
||||||
|
|
||||||
|
|
||||||
@@ -822,16 +823,18 @@ async def update_graph(
|
|||||||
graph: graph_db.Graph,
|
graph: graph_db.Graph,
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
) -> graph_db.GraphModel:
|
) -> graph_db.GraphModel:
|
||||||
|
# Sanity check
|
||||||
if graph.id and graph.id != graph_id:
|
if graph.id and graph.id != graph_id:
|
||||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||||
|
|
||||||
|
# Determine new version
|
||||||
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||||
if not existing_versions:
|
if not existing_versions:
|
||||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||||
|
latest_version_number = max(g.version for g in existing_versions)
|
||||||
|
graph.version = latest_version_number + 1
|
||||||
|
|
||||||
graph.version = max(g.version for g in existing_versions) + 1
|
|
||||||
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
||||||
|
|
||||||
graph = graph_db.make_graph_model(graph, user_id)
|
graph = graph_db.make_graph_model(graph, user_id)
|
||||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||||
graph.validate_graph(for_run=False)
|
graph.validate_graph(for_run=False)
|
||||||
@@ -839,23 +842,27 @@ async def update_graph(
|
|||||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||||
|
|
||||||
if new_graph_version.is_active:
|
if new_graph_version.is_active:
|
||||||
await library_db.update_library_agent_version_and_settings(
|
# Keep the library agent up to date with the new active version
|
||||||
user_id, new_graph_version
|
await _update_library_agent_version_and_settings(user_id, new_graph_version)
|
||||||
)
|
|
||||||
|
# Handle activation of the new graph first to ensure continuity
|
||||||
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
||||||
|
# Ensure new version is the only active version
|
||||||
await graph_db.set_graph_active_version(
|
await graph_db.set_graph_active_version(
|
||||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||||
)
|
)
|
||||||
if current_active_version:
|
if current_active_version:
|
||||||
|
# Handle deactivation of the previously active version
|
||||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||||
|
|
||||||
|
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
|
||||||
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
||||||
graph_id,
|
graph_id,
|
||||||
new_graph_version.version,
|
new_graph_version.version,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
include_subgraphs=True,
|
include_subgraphs=True,
|
||||||
)
|
)
|
||||||
assert new_graph_version_with_subgraphs
|
assert new_graph_version_with_subgraphs # make type checker happy
|
||||||
return new_graph_version_with_subgraphs
|
return new_graph_version_with_subgraphs
|
||||||
|
|
||||||
|
|
||||||
@@ -893,15 +900,33 @@ async def set_graph_active_version(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Keep the library agent up to date with the new active version
|
# Keep the library agent up to date with the new active version
|
||||||
await library_db.update_library_agent_version_and_settings(
|
await _update_library_agent_version_and_settings(user_id, new_active_graph)
|
||||||
user_id, new_active_graph
|
|
||||||
)
|
|
||||||
|
|
||||||
if current_active_graph and current_active_graph.version != new_active_version:
|
if current_active_graph and current_active_graph.version != new_active_version:
|
||||||
# Handle deactivation of the previously active version
|
# Handle deactivation of the previously active version
|
||||||
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_library_agent_version_and_settings(
|
||||||
|
user_id: str, agent_graph: graph_db.GraphModel
|
||||||
|
) -> library_model.LibraryAgent:
|
||||||
|
library = await library_db.update_agent_version_in_library(
|
||||||
|
user_id, agent_graph.id, agent_graph.version
|
||||||
|
)
|
||||||
|
updated_settings = GraphSettings.from_graph(
|
||||||
|
graph=agent_graph,
|
||||||
|
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||||
|
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||||
|
)
|
||||||
|
if updated_settings != library.settings:
|
||||||
|
library = await library_db.update_library_agent(
|
||||||
|
library_agent_id=library.id,
|
||||||
|
user_id=user_id,
|
||||||
|
settings=updated_settings,
|
||||||
|
)
|
||||||
|
return library
|
||||||
|
|
||||||
|
|
||||||
@v1_router.patch(
|
@v1_router.patch(
|
||||||
path="/graphs/{graph_id}/settings",
|
path="/graphs/{graph_id}/settings",
|
||||||
summary="Update graph settings",
|
summary="Update graph settings",
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
import base64
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import shlex
|
import shlex
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
@@ -23,11 +21,6 @@ from backend.data.model import (
|
|||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Maximum size for binary files to extract (50MB)
|
|
||||||
MAX_BINARY_FILE_SIZE = 50 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
class ClaudeCodeExecutionError(Exception):
|
class ClaudeCodeExecutionError(Exception):
|
||||||
"""Exception raised when Claude Code execution fails.
|
"""Exception raised when Claude Code execution fails.
|
||||||
@@ -187,9 +180,7 @@ class ClaudeCodeBlock(Block):
|
|||||||
path: str
|
path: str
|
||||||
relative_path: str # Path relative to working directory (for GitHub, etc.)
|
relative_path: str # Path relative to working directory (for GitHub, etc.)
|
||||||
name: str
|
name: str
|
||||||
content: str # Text content for text files, empty string for binary files
|
content: str
|
||||||
is_binary: bool = False # True if this is a binary file
|
|
||||||
content_base64: Optional[str] = None # Base64-encoded content for binary files
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
response: str = SchemaField(
|
response: str = SchemaField(
|
||||||
@@ -197,11 +188,8 @@ class ClaudeCodeBlock(Block):
|
|||||||
)
|
)
|
||||||
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
|
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
|
||||||
description=(
|
description=(
|
||||||
"List of files created/modified by Claude Code during this execution. "
|
"List of text files created/modified by Claude Code during this execution. "
|
||||||
"Each file has 'path', 'relative_path', 'name', 'content', 'is_binary', "
|
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
|
||||||
"and 'content_base64' fields. For text files, 'content' contains the text "
|
|
||||||
"and 'is_binary' is False. For binary files (PDFs, images, etc.), "
|
|
||||||
"'is_binary' is True and 'content_base64' contains the base64-encoded data."
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conversation_history: str = SchemaField(
|
conversation_history: str = SchemaField(
|
||||||
@@ -264,8 +252,6 @@ class ClaudeCodeBlock(Block):
|
|||||||
"relative_path": "index.html",
|
"relative_path": "index.html",
|
||||||
"name": "index.html",
|
"name": "index.html",
|
||||||
"content": "<html>Hello World</html>",
|
"content": "<html>Hello World</html>",
|
||||||
"is_binary": False,
|
|
||||||
"content_base64": None,
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
@@ -286,8 +272,6 @@ class ClaudeCodeBlock(Block):
|
|||||||
relative_path="index.html",
|
relative_path="index.html",
|
||||||
name="index.html",
|
name="index.html",
|
||||||
content="<html>Hello World</html>",
|
content="<html>Hello World</html>",
|
||||||
is_binary=False,
|
|
||||||
content_base64=None,
|
|
||||||
)
|
)
|
||||||
], # files
|
], # files
|
||||||
"User: Create a hello world HTML file\n"
|
"User: Create a hello world HTML file\n"
|
||||||
@@ -547,6 +531,7 @@ class ClaudeCodeBlock(Block):
|
|||||||
".env",
|
".env",
|
||||||
".gitignore",
|
".gitignore",
|
||||||
".dockerfile",
|
".dockerfile",
|
||||||
|
"Dockerfile",
|
||||||
".vue",
|
".vue",
|
||||||
".svelte",
|
".svelte",
|
||||||
".astro",
|
".astro",
|
||||||
@@ -555,44 +540,6 @@ class ClaudeCodeBlock(Block):
|
|||||||
".tex",
|
".tex",
|
||||||
".csv",
|
".csv",
|
||||||
".log",
|
".log",
|
||||||
".svg", # SVG is XML-based text
|
|
||||||
}
|
|
||||||
|
|
||||||
# Binary file extensions we can read and base64-encode
|
|
||||||
binary_extensions = {
|
|
||||||
# Images
|
|
||||||
".png",
|
|
||||||
".jpg",
|
|
||||||
".jpeg",
|
|
||||||
".gif",
|
|
||||||
".webp",
|
|
||||||
".ico",
|
|
||||||
".bmp",
|
|
||||||
".tiff",
|
|
||||||
".tif",
|
|
||||||
# Documents
|
|
||||||
".pdf",
|
|
||||||
# Archives (useful for downloads)
|
|
||||||
".zip",
|
|
||||||
".tar",
|
|
||||||
".gz",
|
|
||||||
".7z",
|
|
||||||
# Audio/Video (if small enough)
|
|
||||||
".mp3",
|
|
||||||
".wav",
|
|
||||||
".mp4",
|
|
||||||
".webm",
|
|
||||||
# Other binary formats
|
|
||||||
".woff",
|
|
||||||
".woff2",
|
|
||||||
".ttf",
|
|
||||||
".otf",
|
|
||||||
".eot",
|
|
||||||
".bin",
|
|
||||||
".exe",
|
|
||||||
".dll",
|
|
||||||
".so",
|
|
||||||
".dylib",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -617,26 +564,10 @@ class ClaudeCodeBlock(Block):
|
|||||||
if not file_path:
|
if not file_path:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if it's a text file we can read (case-insensitive)
|
# Check if it's a text file we can read
|
||||||
file_path_lower = file_path.lower()
|
|
||||||
is_text = any(
|
is_text = any(
|
||||||
file_path_lower.endswith(ext) for ext in text_extensions
|
file_path.endswith(ext) for ext in text_extensions
|
||||||
) or file_path_lower.endswith("dockerfile")
|
) or file_path.endswith("Dockerfile")
|
||||||
|
|
||||||
# Check if it's a binary file we should extract
|
|
||||||
is_binary = any(
|
|
||||||
file_path_lower.endswith(ext) for ext in binary_extensions
|
|
||||||
)
|
|
||||||
|
|
||||||
# Helper to extract filename and relative path
|
|
||||||
def get_file_info(path: str, work_dir: str) -> tuple[str, str]:
|
|
||||||
name = path.split("/")[-1]
|
|
||||||
rel_path = path
|
|
||||||
if path.startswith(work_dir):
|
|
||||||
rel_path = path[len(work_dir) :]
|
|
||||||
if rel_path.startswith("/"):
|
|
||||||
rel_path = rel_path[1:]
|
|
||||||
return name, rel_path
|
|
||||||
|
|
||||||
if is_text:
|
if is_text:
|
||||||
try:
|
try:
|
||||||
@@ -645,72 +576,32 @@ class ClaudeCodeBlock(Block):
|
|||||||
if isinstance(content, bytes):
|
if isinstance(content, bytes):
|
||||||
content = content.decode("utf-8", errors="replace")
|
content = content.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
file_name, relative_path = get_file_info(
|
# Extract filename from path
|
||||||
file_path, working_directory
|
file_name = file_path.split("/")[-1]
|
||||||
)
|
|
||||||
|
# Calculate relative path by stripping working directory
|
||||||
|
relative_path = file_path
|
||||||
|
if file_path.startswith(working_directory):
|
||||||
|
relative_path = file_path[len(working_directory) :]
|
||||||
|
# Remove leading slash if present
|
||||||
|
if relative_path.startswith("/"):
|
||||||
|
relative_path = relative_path[1:]
|
||||||
|
|
||||||
files.append(
|
files.append(
|
||||||
ClaudeCodeBlock.FileOutput(
|
ClaudeCodeBlock.FileOutput(
|
||||||
path=file_path,
|
path=file_path,
|
||||||
relative_path=relative_path,
|
relative_path=relative_path,
|
||||||
name=file_name,
|
name=file_name,
|
||||||
content=content,
|
content=content,
|
||||||
is_binary=False,
|
|
||||||
content_base64=None,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.warning(f"Failed to read text file {file_path}: {e}")
|
# Skip files that can't be read
|
||||||
elif is_binary:
|
pass
|
||||||
try:
|
|
||||||
# Check file size before reading to avoid OOM
|
|
||||||
stat_result = await sandbox.commands.run(
|
|
||||||
f"stat -c %s {shlex.quote(file_path)} 2>/dev/null"
|
|
||||||
)
|
|
||||||
if stat_result.exit_code != 0 or not stat_result.stdout:
|
|
||||||
logger.warning(
|
|
||||||
f"Skipping binary file {file_path}: "
|
|
||||||
f"could not determine file size"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
file_size = int(stat_result.stdout.strip())
|
|
||||||
if file_size > MAX_BINARY_FILE_SIZE:
|
|
||||||
logger.warning(
|
|
||||||
f"Skipping binary file {file_path}: "
|
|
||||||
f"size {file_size} exceeds limit "
|
|
||||||
f"{MAX_BINARY_FILE_SIZE}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Read binary file as bytes using format="bytes"
|
except Exception:
|
||||||
content_bytes = await sandbox.files.read(
|
# If file extraction fails, return empty results
|
||||||
file_path, format="bytes"
|
pass
|
||||||
)
|
|
||||||
|
|
||||||
# Base64 encode the binary content
|
|
||||||
content_b64 = base64.b64encode(content_bytes).decode(
|
|
||||||
"ascii"
|
|
||||||
)
|
|
||||||
|
|
||||||
file_name, relative_path = get_file_info(
|
|
||||||
file_path, working_directory
|
|
||||||
)
|
|
||||||
files.append(
|
|
||||||
ClaudeCodeBlock.FileOutput(
|
|
||||||
path=file_path,
|
|
||||||
relative_path=relative_path,
|
|
||||||
name=file_name,
|
|
||||||
content="", # Empty for binary files
|
|
||||||
is_binary=True,
|
|
||||||
content_base64=content_b64,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to read binary file {file_path}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"File extraction failed: {e}")
|
|
||||||
|
|
||||||
return files
|
return files
|
||||||
|
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
|
|||||||
"""ElevenLabs integration blocks - test credentials and shared utilities."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
|
||||||
provider="elevenlabs",
|
|
||||||
api_key=SecretStr("mock-elevenlabs-api-key"),
|
|
||||||
title="Mock ElevenLabs API key",
|
|
||||||
expires_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
TEST_CREDENTIALS_INPUT = {
|
|
||||||
"provider": TEST_CREDENTIALS.provider,
|
|
||||||
"id": TEST_CREDENTIALS.id,
|
|
||||||
"type": TEST_CREDENTIALS.type,
|
|
||||||
"title": TEST_CREDENTIALS.title,
|
|
||||||
}
|
|
||||||
|
|
||||||
ElevenLabsCredentials = APIKeyCredentials
|
|
||||||
ElevenLabsCredentialsInput = CredentialsMetaInput[
|
|
||||||
Literal[ProviderName.ELEVENLABS], Literal["api_key"]
|
|
||||||
]
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
"""Text encoding block for converting special characters to escape sequences."""
|
|
||||||
|
|
||||||
import codecs
|
|
||||||
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
|
|
||||||
|
|
||||||
class TextEncoderBlock(Block):
|
|
||||||
"""
|
|
||||||
Encodes a string by converting special characters into escape sequences.
|
|
||||||
|
|
||||||
This block is the inverse of TextDecoderBlock. It takes text containing
|
|
||||||
special characters (like newlines, tabs, etc.) and converts them into
|
|
||||||
their escape sequence representations (e.g., newline becomes \\n).
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
"""Input schema for TextEncoderBlock."""
|
|
||||||
|
|
||||||
text: str = SchemaField(
|
|
||||||
description="A string containing special characters to be encoded",
|
|
||||||
placeholder="Your text with newlines and quotes to encode",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
"""Output schema for TextEncoderBlock."""
|
|
||||||
|
|
||||||
encoded_text: str = SchemaField(
|
|
||||||
description="The encoded text with special characters converted to escape sequences"
|
|
||||||
)
|
|
||||||
error: str = SchemaField(description="Error message if encoding fails")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
|
|
||||||
description="Encodes a string by converting special characters into escape sequences",
|
|
||||||
categories={BlockCategory.TEXT},
|
|
||||||
input_schema=TextEncoderBlock.Input,
|
|
||||||
output_schema=TextEncoderBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"text": """Hello
|
|
||||||
World!
|
|
||||||
This is a "quoted" string."""
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
(
|
|
||||||
"encoded_text",
|
|
||||||
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
|
||||||
"""
|
|
||||||
Encode the input text by converting special characters to escape sequences.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_data: The input containing the text to encode.
|
|
||||||
**kwargs: Additional keyword arguments (unused).
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
The encoded text with escape sequences, or an error message if encoding fails.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
|
|
||||||
"utf-8"
|
|
||||||
)
|
|
||||||
yield "encoded_text", encoded_text
|
|
||||||
except Exception as e:
|
|
||||||
yield "error", f"Encoding error: {str(e)}"
|
|
||||||
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
webset = await aexa.websets.get(id=input_data.external_id)
|
webset = aexa.websets.get(id=input_data.external_id)
|
||||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||||
|
|
||||||
yield "webset", webset_result
|
yield "webset", webset_result
|
||||||
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
|||||||
count=input_data.search_count,
|
count=input_data.search_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
webset = await aexa.websets.create(
|
webset = aexa.websets.create(
|
||||||
params=CreateWebsetParameters(
|
params=CreateWebsetParameters(
|
||||||
search=search_params,
|
search=search_params,
|
||||||
external_id=input_data.external_id,
|
external_id=input_data.external_id,
|
||||||
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
|
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.list(
|
response = aexa.websets.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
|
sdk_webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
|
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
deleted_webset.status.value
|
deleted_webset.status.value
|
||||||
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
|
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
canceled_webset.status.value
|
canceled_webset.status.value
|
||||||
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
|
|||||||
entity["description"] = input_data.entity_description
|
entity["description"] = input_data.entity_description
|
||||||
payload["entity"] = entity
|
payload["entity"] = entity
|
||||||
|
|
||||||
sdk_preview = await aexa.websets.preview(params=payload)
|
sdk_preview = aexa.websets.preview(params=payload)
|
||||||
|
|
||||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||||
|
|
||||||
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Extract basic info
|
# Extract basic info
|
||||||
webset_id = webset.id
|
webset_id = webset.id
|
||||||
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
total_items = 0
|
total_items = 0
|
||||||
|
|
||||||
if input_data.include_sample_items and input_data.sample_size > 0:
|
if input_data.include_sample_items and input_data.sample_size > 0:
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
sample_items_data = [
|
sample_items_data = [
|
||||||
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset details
|
# Get webset details
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = await aexa.websets.enrichments.create(
|
sdk_enrichment = aexa.websets.enrichments.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_enrich = await aexa.websets.enrichments.get(
|
current_enrich = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=enrichment_id
|
webset_id=input_data.webset_id, id=enrichment_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
|
|
||||||
if current_status in ["completed", "failed", "cancelled"]:
|
if current_status in ["completed", "failed", "cancelled"]:
|
||||||
# Estimate items from webset searches
|
# Estimate items from webset searches
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
for search in webset.searches:
|
for search in webset.searches:
|
||||||
if search.progress:
|
if search.progress:
|
||||||
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = await aexa.websets.enrichments.get(
|
sdk_enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_enrichment = await aexa.websets.enrichments.delete(
|
deleted_enrichment = aexa.websets.enrichments.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_enrichment = await aexa.websets.enrichments.cancel(
|
canceled_enrichment = aexa.websets.enrichments.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to estimate how many items were enriched before cancellation
|
# Try to estimate how many items were enriched before cancellation
|
||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=100
|
webset_id=input_data.webset_id, limit=100
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
# Create mock SDK import object
|
# Create mock SDK import object
|
||||||
mock_import = MagicMock()
|
mock_import = MagicMock()
|
||||||
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
imports=MagicMock(create=AsyncMock(return_value=mock_import))
|
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_import = await aexa.websets.imports.create(
|
sdk_import = aexa.websets.imports.create(
|
||||||
params=payload, csv_data=input_data.csv_data
|
params=payload, csv_data=input_data.csv_data
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
|
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
|
||||||
|
|
||||||
import_obj = ImportModel.from_sdk(sdk_import)
|
import_obj = ImportModel.from_sdk(sdk_import)
|
||||||
|
|
||||||
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.imports.list(
|
response = aexa.websets.imports.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -474,9 +474,7 @@ class ExaDeleteImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_import = await aexa.websets.imports.delete(
|
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
|
||||||
import_id=input_data.import_id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "import_id", deleted_import.id
|
yield "import_id", deleted_import.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -575,14 +573,14 @@ class ExaExportWebsetBlock(Block):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create async iterator for list_all
|
# Create mock iterator
|
||||||
async def async_item_iterator(*args, **kwargs):
|
mock_items = [mock_item1, mock_item2]
|
||||||
for item in [mock_item1, mock_item2]:
|
|
||||||
yield item
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
|
websets=MagicMock(
|
||||||
|
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -604,7 +602,7 @@ class ExaExportWebsetBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
async for sdk_item in item_iterator:
|
for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_item = await aexa.websets.items.get(
|
sdk_item = aexa.websets.items.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
response = None
|
response = None
|
||||||
|
|
||||||
while time.time() - start_time < input_data.wait_timeout:
|
while time.time() - start_time < input_data.wait_timeout:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
interval = min(interval * 1.2, 10)
|
interval = min(interval * 1.2, 10)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_item = await aexa.websets.items.delete(
|
deleted_item = aexa.websets.items.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
async for sdk_item in item_iterator:
|
for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
entity_type = "unknown"
|
entity_type = "unknown"
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Get sample items if requested
|
# Get sample items if requested
|
||||||
sample_items: List[WebsetItemModel] = []
|
sample_items: List[WebsetItemModel] = []
|
||||||
if input_data.sample_size > 0:
|
if input_data.sample_size > 0:
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
# Convert to our stable models
|
# Convert to our stable models
|
||||||
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get items starting from cursor
|
# Get items starting from cursor
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.since_cursor,
|
cursor=input_data.since_cursor,
|
||||||
limit=input_data.max_items,
|
limit=input_data.max_items,
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
# Create mock SDK monitor object
|
# Create mock SDK monitor object
|
||||||
mock_monitor = MagicMock()
|
mock_monitor = MagicMock()
|
||||||
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
|
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.create(params=payload)
|
sdk_monitor = aexa.websets.monitors.create(params=payload)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.update(
|
sdk_monitor = aexa.websets.monitors.update(
|
||||||
monitor_id=input_data.monitor_id, params=payload
|
monitor_id=input_data.monitor_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -522,9 +522,7 @@ class ExaDeleteMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_monitor = await aexa.websets.monitors.delete(
|
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
|
||||||
monitor_id=input_data.monitor_id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "monitor_id", deleted_monitor.id
|
yield "monitor_id", deleted_monitor.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -581,7 +579,7 @@ class ExaListMonitorsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.monitors.list(
|
response = aexa.websets.monitors.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
WebsetTargetStatus.IDLE,
|
WebsetTargetStatus.IDLE,
|
||||||
WebsetTargetStatus.ANY_COMPLETE,
|
WebsetTargetStatus.ANY_COMPLETE,
|
||||||
]:
|
]:
|
||||||
final_webset = await aexa.websets.wait_until_idle(
|
final_webset = aexa.websets.wait_until_idle(
|
||||||
id=input_data.webset_id,
|
id=input_data.webset_id,
|
||||||
timeout=input_data.timeout,
|
timeout=input_data.timeout,
|
||||||
poll_interval=input_data.check_interval,
|
poll_interval=input_data.check_interval,
|
||||||
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
interval = input_data.check_interval
|
interval = input_data.check_interval
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current webset status
|
# Get current webset status
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
current_status = (
|
current_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
|
|
||||||
# Timeout reached
|
# Timeout reached
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
final_status = (
|
final_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current search status using SDK
|
# Get current search status using SDK
|
||||||
search = await aexa.websets.searches.get(
|
search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
search = await aexa.websets.searches.get(
|
search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current enrichment status using SDK
|
# Get current enrichment status using SDK
|
||||||
enrichment = await aexa.websets.enrichments.get(
|
enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
enrichment = await aexa.websets.enrichments.get(
|
enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
) -> tuple[list[SampleEnrichmentModel], int]:
|
) -> tuple[list[SampleEnrichmentModel], int]:
|
||||||
"""Get sample enriched data and count."""
|
"""Get sample enriched data and count."""
|
||||||
# Get a few items to see enrichment results using SDK
|
# Get a few items to see enrichment results using SDK
|
||||||
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
|
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||||
|
|
||||||
sample_data: list[SampleEnrichmentModel] = []
|
sample_data: list[SampleEnrichmentModel] = []
|
||||||
enriched_count = 0
|
enriched_count = 0
|
||||||
|
|||||||
@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
|
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.create(
|
sdk_search = aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
poll_start = time.time()
|
poll_start = time.time()
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_search = await aexa.websets.searches.get(
|
current_search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=search_id
|
webset_id=input_data.webset_id, id=search_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.get(
|
sdk_search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_search = await aexa.websets.searches.cancel(
|
canceled_search = aexa.websets.searches.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset to check existing searches
|
# Get webset to check existing searches
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Look for existing search with same query
|
# Look for existing search with same query
|
||||||
existing_search = None
|
existing_search = None
|
||||||
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
if input_data.entity_type != SearchEntityType.AUTO:
|
if input_data.entity_type != SearchEntityType.AUTO:
|
||||||
payload["entity"] = {"type": input_data.entity_type.value}
|
payload["entity"] = {"type": input_data.entity_type.value}
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.create(
|
sdk_search = aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -21,71 +21,43 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class HumanInTheLoopBlock(Block):
|
class HumanInTheLoopBlock(Block):
|
||||||
"""
|
"""
|
||||||
Pauses execution and waits for human approval or rejection of the data.
|
This block pauses execution and waits for human approval or modification of the data.
|
||||||
|
|
||||||
When executed, this block creates a pending review entry and sets the node execution
|
When executed, it creates a pending review entry and sets the node execution status
|
||||||
status to REVIEW. The execution remains paused until a human user either approves
|
to REVIEW. The execution will remain paused until a human user either:
|
||||||
or rejects the data.
|
- Approves the data (with or without modifications)
|
||||||
|
- Rejects the data
|
||||||
|
|
||||||
**How it works:**
|
This is useful for workflows that require human validation or intervention before
|
||||||
- The input data is presented to a human reviewer
|
proceeding to the next steps.
|
||||||
- The reviewer can approve or reject (and optionally modify the data if editable)
|
|
||||||
- On approval: the data flows out through the `approved_data` output pin
|
|
||||||
- On rejection: the data flows out through the `rejected_data` output pin
|
|
||||||
|
|
||||||
**Important:** The output pins yield the actual data itself, NOT status strings.
|
|
||||||
The approval/rejection decision determines WHICH output pin fires, not the value.
|
|
||||||
You do NOT need to compare the output to "APPROVED" or "REJECTED" - simply connect
|
|
||||||
downstream blocks to the appropriate output pin for each case.
|
|
||||||
|
|
||||||
**Example usage:**
|
|
||||||
- Connect `approved_data` → next step in your workflow (data was approved)
|
|
||||||
- Connect `rejected_data` → error handling or notification (data was rejected)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
class Input(BlockSchemaInput):
|
||||||
data: Any = SchemaField(
|
data: Any = SchemaField(description="The data to be reviewed by a human user")
|
||||||
description="The data to be reviewed by a human user. "
|
|
||||||
"This exact data will be passed through to either approved_data or "
|
|
||||||
"rejected_data output based on the reviewer's decision."
|
|
||||||
)
|
|
||||||
name: str = SchemaField(
|
name: str = SchemaField(
|
||||||
description="A descriptive name for what this data represents. "
|
description="A descriptive name for what this data represents",
|
||||||
"This helps the reviewer understand what they are reviewing.",
|
|
||||||
)
|
)
|
||||||
editable: bool = SchemaField(
|
editable: bool = SchemaField(
|
||||||
description="Whether the human reviewer can edit the data before "
|
description="Whether the human reviewer can edit the data",
|
||||||
"approving or rejecting it",
|
|
||||||
default=True,
|
default=True,
|
||||||
advanced=True,
|
advanced=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
approved_data: Any = SchemaField(
|
approved_data: Any = SchemaField(
|
||||||
description="Outputs the input data when the reviewer APPROVES it. "
|
description="The data when approved (may be modified by reviewer)"
|
||||||
"The value is the actual data itself (not a status string like 'APPROVED'). "
|
|
||||||
"If the reviewer edited the data, this contains the modified version. "
|
|
||||||
"Connect downstream blocks here for the 'approved' workflow path."
|
|
||||||
)
|
)
|
||||||
rejected_data: Any = SchemaField(
|
rejected_data: Any = SchemaField(
|
||||||
description="Outputs the input data when the reviewer REJECTS it. "
|
description="The data when rejected (may be modified by reviewer)"
|
||||||
"The value is the actual data itself (not a status string like 'REJECTED'). "
|
|
||||||
"If the reviewer edited the data, this contains the modified version. "
|
|
||||||
"Connect downstream blocks here for the 'rejected' workflow path."
|
|
||||||
)
|
)
|
||||||
review_message: str = SchemaField(
|
review_message: str = SchemaField(
|
||||||
description="Optional message provided by the reviewer explaining their "
|
description="Any message provided by the reviewer", default=""
|
||||||
"decision. Only outputs when the reviewer provides a message; "
|
|
||||||
"this pin does not fire if no message was given.",
|
|
||||||
default="",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
id="8b2a7b3c-6e9d-4a5f-8c1b-2e3f4a5b6c7d",
|
id="8b2a7b3c-6e9d-4a5f-8c1b-2e3f4a5b6c7d",
|
||||||
description="Pause execution for human review. Data flows through "
|
description="Pause execution and wait for human approval or modification of data",
|
||||||
"approved_data or rejected_data output based on the reviewer's decision. "
|
|
||||||
"Outputs contain the actual data, not status strings.",
|
|
||||||
categories={BlockCategory.BASIC},
|
categories={BlockCategory.BASIC},
|
||||||
input_schema=HumanInTheLoopBlock.Input,
|
input_schema=HumanInTheLoopBlock.Input,
|
||||||
output_schema=HumanInTheLoopBlock.Output,
|
output_schema=HumanInTheLoopBlock.Output,
|
||||||
|
|||||||
@@ -162,16 +162,8 @@ class LinearClient:
|
|||||||
"searchTerm": team_name,
|
"searchTerm": team_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
result = await self.query(query, variables)
|
team_id = await self.query(query, variables)
|
||||||
nodes = result["teams"]["nodes"]
|
return team_id["teams"]["nodes"][0]["id"]
|
||||||
|
|
||||||
if not nodes:
|
|
||||||
raise LinearAPIException(
|
|
||||||
f"Team '{team_name}' not found. Check the team name or key and try again.",
|
|
||||||
status_code=404,
|
|
||||||
)
|
|
||||||
|
|
||||||
return nodes[0]["id"]
|
|
||||||
except LinearAPIException as e:
|
except LinearAPIException as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@@ -248,44 +240,17 @@ class LinearClient:
|
|||||||
except LinearAPIException as e:
|
except LinearAPIException as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def try_search_issues(
|
async def try_search_issues(self, term: str) -> list[Issue]:
|
||||||
self,
|
|
||||||
term: str,
|
|
||||||
max_results: int = 10,
|
|
||||||
team_id: str | None = None,
|
|
||||||
) -> list[Issue]:
|
|
||||||
try:
|
try:
|
||||||
query = """
|
query = """
|
||||||
query SearchIssues(
|
query SearchIssues($term: String!, $includeComments: Boolean!) {
|
||||||
$term: String!,
|
searchIssues(term: $term, includeComments: $includeComments) {
|
||||||
$first: Int,
|
|
||||||
$teamId: String
|
|
||||||
) {
|
|
||||||
searchIssues(
|
|
||||||
term: $term,
|
|
||||||
first: $first,
|
|
||||||
teamId: $teamId
|
|
||||||
) {
|
|
||||||
nodes {
|
nodes {
|
||||||
id
|
id
|
||||||
identifier
|
identifier
|
||||||
title
|
title
|
||||||
description
|
description
|
||||||
priority
|
priority
|
||||||
createdAt
|
|
||||||
state {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
type
|
|
||||||
}
|
|
||||||
project {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
}
|
|
||||||
assignee {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -293,8 +258,7 @@ class LinearClient:
|
|||||||
|
|
||||||
variables: dict[str, Any] = {
|
variables: dict[str, Any] = {
|
||||||
"term": term,
|
"term": term,
|
||||||
"first": max_results,
|
"includeComments": True,
|
||||||
"teamId": team_id,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
issues = await self.query(query, variables)
|
issues = await self.query(query, variables)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from ._config import (
|
|||||||
LinearScope,
|
LinearScope,
|
||||||
linear,
|
linear,
|
||||||
)
|
)
|
||||||
from .models import CreateIssueResponse, Issue, State
|
from .models import CreateIssueResponse, Issue
|
||||||
|
|
||||||
|
|
||||||
class LinearCreateIssueBlock(Block):
|
class LinearCreateIssueBlock(Block):
|
||||||
@@ -135,20 +135,9 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
description="Linear credentials with read permissions",
|
description="Linear credentials with read permissions",
|
||||||
required_scopes={LinearScope.READ},
|
required_scopes={LinearScope.READ},
|
||||||
)
|
)
|
||||||
max_results: int = SchemaField(
|
|
||||||
description="Maximum number of results to return",
|
|
||||||
default=10,
|
|
||||||
ge=1,
|
|
||||||
le=100,
|
|
||||||
)
|
|
||||||
team_name: str | None = SchemaField(
|
|
||||||
description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
issues: list[Issue] = SchemaField(description="List of issues")
|
issues: list[Issue] = SchemaField(description="List of issues")
|
||||||
error: str = SchemaField(description="Error message if the search failed")
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -156,11 +145,8 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
description="Searches for issues on Linear",
|
description="Searches for issues on Linear",
|
||||||
input_schema=self.Input,
|
input_schema=self.Input,
|
||||||
output_schema=self.Output,
|
output_schema=self.Output,
|
||||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
|
||||||
test_input={
|
test_input={
|
||||||
"term": "Test issue",
|
"term": "Test issue",
|
||||||
"max_results": 10,
|
|
||||||
"team_name": None,
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS_OAUTH,
|
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||||
@@ -170,14 +156,10 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
[
|
[
|
||||||
Issue(
|
Issue(
|
||||||
id="abc123",
|
id="abc123",
|
||||||
identifier="TST-123",
|
identifier="abc123",
|
||||||
title="Test issue",
|
title="Test issue",
|
||||||
description="Test description",
|
description="Test description",
|
||||||
priority=1,
|
priority=1,
|
||||||
state=State(
|
|
||||||
id="state1", name="In Progress", type="started"
|
|
||||||
),
|
|
||||||
createdAt="2026-01-15T10:00:00.000Z",
|
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -186,12 +168,10 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
"search_issues": lambda *args, **kwargs: [
|
"search_issues": lambda *args, **kwargs: [
|
||||||
Issue(
|
Issue(
|
||||||
id="abc123",
|
id="abc123",
|
||||||
identifier="TST-123",
|
identifier="abc123",
|
||||||
title="Test issue",
|
title="Test issue",
|
||||||
description="Test description",
|
description="Test description",
|
||||||
priority=1,
|
priority=1,
|
||||||
state=State(id="state1", name="In Progress", type="started"),
|
|
||||||
createdAt="2026-01-15T10:00:00.000Z",
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -201,22 +181,10 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
async def search_issues(
|
async def search_issues(
|
||||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||||
term: str,
|
term: str,
|
||||||
max_results: int = 10,
|
|
||||||
team_name: str | None = None,
|
|
||||||
) -> list[Issue]:
|
) -> list[Issue]:
|
||||||
client = LinearClient(credentials=credentials)
|
client = LinearClient(credentials=credentials)
|
||||||
|
response: list[Issue] = await client.try_search_issues(term=term)
|
||||||
# Resolve team name to ID if provided
|
return response
|
||||||
# Raises LinearAPIException with descriptive message if team not found
|
|
||||||
team_id: str | None = None
|
|
||||||
if team_name:
|
|
||||||
team_id = await client.try_get_team_by_name(team_name=team_name)
|
|
||||||
|
|
||||||
return await client.try_search_issues(
|
|
||||||
term=term,
|
|
||||||
max_results=max_results,
|
|
||||||
team_id=team_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
@@ -228,10 +196,7 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
"""Execute the issue search"""
|
"""Execute the issue search"""
|
||||||
try:
|
try:
|
||||||
issues = await self.search_issues(
|
issues = await self.search_issues(
|
||||||
credentials=credentials,
|
credentials=credentials, term=input_data.term
|
||||||
term=input_data.term,
|
|
||||||
max_results=input_data.max_results,
|
|
||||||
team_name=input_data.team_name,
|
|
||||||
)
|
)
|
||||||
yield "issues", issues
|
yield "issues", issues
|
||||||
except LinearAPIException as e:
|
except LinearAPIException as e:
|
||||||
|
|||||||
@@ -36,21 +36,12 @@ class Project(BaseModel):
|
|||||||
content: str | None = None
|
content: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class State(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
type: str | None = (
|
|
||||||
None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Issue(BaseModel):
|
class Issue(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
identifier: str
|
identifier: str
|
||||||
title: str
|
title: str
|
||||||
description: str | None
|
description: str | None
|
||||||
priority: int
|
priority: int
|
||||||
state: State | None = None
|
|
||||||
project: Project | None = None
|
project: Project | None = None
|
||||||
createdAt: str | None = None
|
createdAt: str | None = None
|
||||||
comments: list[Comment] | None = None
|
comments: list[Comment] | None = None
|
||||||
|
|||||||
@@ -115,7 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
||||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||||
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
|
||||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||||
# AI/ML API models
|
# AI/ML API models
|
||||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||||
@@ -271,9 +270,6 @@ MODEL_METADATA = {
|
|||||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||||
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
||||||
), # claude-4-sonnet-20250514
|
), # claude-4-sonnet-20250514
|
||||||
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
|
|
||||||
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
|
|
||||||
), # claude-opus-4-6
|
|
||||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||||
), # claude-opus-4-5-20251101
|
), # claude-opus-4-5-20251101
|
||||||
@@ -531,12 +527,12 @@ class LLMResponse(BaseModel):
|
|||||||
|
|
||||||
def convert_openai_tool_fmt_to_anthropic(
|
def convert_openai_tool_fmt_to_anthropic(
|
||||||
openai_tools: list[dict] | None = None,
|
openai_tools: list[dict] | None = None,
|
||||||
) -> Iterable[ToolParam] | anthropic.Omit:
|
) -> Iterable[ToolParam] | anthropic.NotGiven:
|
||||||
"""
|
"""
|
||||||
Convert OpenAI tool format to Anthropic tool format.
|
Convert OpenAI tool format to Anthropic tool format.
|
||||||
"""
|
"""
|
||||||
if not openai_tools or len(openai_tools) == 0:
|
if not openai_tools or len(openai_tools) == 0:
|
||||||
return anthropic.omit
|
return anthropic.NOT_GIVEN
|
||||||
|
|
||||||
anthropic_tools = []
|
anthropic_tools = []
|
||||||
for tool in openai_tools:
|
for tool in openai_tools:
|
||||||
@@ -596,10 +592,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
|||||||
|
|
||||||
def get_parallel_tool_calls_param(
|
def get_parallel_tool_calls_param(
|
||||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||||
) -> bool | openai.Omit:
|
):
|
||||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||||
return openai.omit
|
return openai.NOT_GIVEN
|
||||||
return parallel_tool_calls
|
return parallel_tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
246
autogpt_platform/backend/backend/blocks/media.py
Normal file
246
autogpt_platform/backend/backend/blocks/media.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||||
|
from moviepy.video.fx.Loop import Loop
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class MediaDurationBlock(Block):
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
media_in: MediaFileType = SchemaField(
|
||||||
|
description="Media input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
is_video: bool = SchemaField(
|
||||||
|
description="Whether the media is a video (True) or audio (False).",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
duration: float = SchemaField(
|
||||||
|
description="Duration of the media file (in seconds)."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
||||||
|
description="Block to get the duration of a media file.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=MediaDurationBlock.Input,
|
||||||
|
output_schema=MediaDurationBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# 1) Store the input media locally
|
||||||
|
local_media_path = await store_media_file(
|
||||||
|
file=input_data.media_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
media_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_media_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Load the clip
|
||||||
|
if input_data.is_video:
|
||||||
|
clip = VideoFileClip(media_abspath)
|
||||||
|
else:
|
||||||
|
clip = AudioFileClip(media_abspath)
|
||||||
|
|
||||||
|
yield "duration", clip.duration
|
||||||
|
|
||||||
|
|
||||||
|
class LoopVideoBlock(Block):
|
||||||
|
"""
|
||||||
|
Block for looping (repeating) a video clip until a given duration or number of loops.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="The input video (can be a URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
|
||||||
|
duration: Optional[float] = SchemaField(
|
||||||
|
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
|
||||||
|
default=None,
|
||||||
|
ge=0.0,
|
||||||
|
)
|
||||||
|
n_loops: Optional[int] = SchemaField(
|
||||||
|
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
|
||||||
|
default=None,
|
||||||
|
ge=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: str = SchemaField(
|
||||||
|
description="Looped video returned either as a relative path or a data URI."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
||||||
|
description="Block to loop a video to a given duration or number of repeats.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=LoopVideoBlock.Input,
|
||||||
|
output_schema=LoopVideoBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
|
# 1) Store the input video locally
|
||||||
|
local_video_path = await store_media_file(
|
||||||
|
file=input_data.video_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||||
|
|
||||||
|
# 2) Load the clip
|
||||||
|
clip = VideoFileClip(input_abspath)
|
||||||
|
|
||||||
|
# 3) Apply the loop effect
|
||||||
|
looped_clip = clip
|
||||||
|
if input_data.duration:
|
||||||
|
# Loop until we reach the specified duration
|
||||||
|
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
|
||||||
|
elif input_data.n_loops:
|
||||||
|
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
|
||||||
|
else:
|
||||||
|
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||||
|
|
||||||
|
assert isinstance(looped_clip, VideoFileClip)
|
||||||
|
|
||||||
|
# 4) Save the looped output
|
||||||
|
output_filename = MediaFileType(
|
||||||
|
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
||||||
|
)
|
||||||
|
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||||
|
|
||||||
|
looped_clip = looped_clip.with_audio(clip.audio)
|
||||||
|
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||||
|
|
||||||
|
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
|
video_out = await store_media_file(
|
||||||
|
file=output_filename,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
|
|
||||||
|
|
||||||
|
class AddAudioToVideoBlock(Block):
|
||||||
|
"""
|
||||||
|
Block that adds (attaches) an audio track to an existing video.
|
||||||
|
Optionally scale the volume of the new track.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="Video input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
audio_in: MediaFileType = SchemaField(
|
||||||
|
description="Audio input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
volume: float = SchemaField(
|
||||||
|
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||||
|
default=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Final video (with attached audio), as a path or data URI."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||||
|
description="Block to attach an audio file to a video file using moviepy.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=AddAudioToVideoBlock.Input,
|
||||||
|
output_schema=AddAudioToVideoBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
|
# 1) Store the inputs locally
|
||||||
|
local_video_path = await store_media_file(
|
||||||
|
file=input_data.video_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
local_audio_path = await store_media_file(
|
||||||
|
file=input_data.audio_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||||
|
video_abspath = os.path.join(abs_temp_dir, local_video_path)
|
||||||
|
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
|
||||||
|
|
||||||
|
# 2) Load video + audio with moviepy
|
||||||
|
video_clip = VideoFileClip(video_abspath)
|
||||||
|
audio_clip = AudioFileClip(audio_abspath)
|
||||||
|
# Optionally scale volume
|
||||||
|
if input_data.volume != 1.0:
|
||||||
|
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||||
|
|
||||||
|
# 3) Attach the new audio track
|
||||||
|
final_clip = video_clip.with_audio(audio_clip)
|
||||||
|
|
||||||
|
# 4) Write to output file
|
||||||
|
output_filename = MediaFileType(
|
||||||
|
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
||||||
|
)
|
||||||
|
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||||
|
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||||
|
|
||||||
|
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
|
video_out = await store_media_file(
|
||||||
|
file=output_filename,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from backend.blocks.encoder_block import TextEncoderBlock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_encoder_basic():
|
|
||||||
"""Test basic encoding of newlines and special characters."""
|
|
||||||
block = TextEncoderBlock()
|
|
||||||
result = []
|
|
||||||
async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
|
|
||||||
result.append(output)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0][0] == "encoded_text"
|
|
||||||
assert result[0][1] == "Hello\\nWorld"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_encoder_multiple_escapes():
|
|
||||||
"""Test encoding of multiple escape sequences."""
|
|
||||||
block = TextEncoderBlock()
|
|
||||||
result = []
|
|
||||||
async for output in block.run(
|
|
||||||
TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
|
|
||||||
):
|
|
||||||
result.append(output)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0][0] == "encoded_text"
|
|
||||||
assert "\\n" in result[0][1]
|
|
||||||
assert "\\t" in result[0][1]
|
|
||||||
assert "\\r" in result[0][1]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_encoder_unicode():
|
|
||||||
"""Test that unicode characters are handled correctly."""
|
|
||||||
block = TextEncoderBlock()
|
|
||||||
result = []
|
|
||||||
async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
|
|
||||||
result.append(output)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0][0] == "encoded_text"
|
|
||||||
# Unicode characters should be escaped as \uXXXX sequences
|
|
||||||
assert "\\n" in result[0][1]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_encoder_empty_string():
|
|
||||||
"""Test encoding of an empty string."""
|
|
||||||
block = TextEncoderBlock()
|
|
||||||
result = []
|
|
||||||
async for output in block.run(TextEncoderBlock.Input(text="")):
|
|
||||||
result.append(output)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0][0] == "encoded_text"
|
|
||||||
assert result[0][1] == ""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_encoder_error_handling():
|
|
||||||
"""Test that encoding errors are handled gracefully."""
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
block = TextEncoderBlock()
|
|
||||||
result = []
|
|
||||||
|
|
||||||
with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
|
|
||||||
async for output in block.run(TextEncoderBlock.Input(text="test")):
|
|
||||||
result.append(output)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0][0] == "error"
|
|
||||||
assert "Mocked encoding error" in result[0][1]
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Video editing blocks for AutoGPT Platform.
|
|
||||||
|
|
||||||
This module provides blocks for:
|
|
||||||
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
|
|
||||||
- Clipping/trimming video segments
|
|
||||||
- Concatenating multiple videos
|
|
||||||
- Adding text overlays
|
|
||||||
- Adding AI-generated narration
|
|
||||||
- Getting media duration
|
|
||||||
- Looping videos
|
|
||||||
- Adding audio to videos
|
|
||||||
|
|
||||||
Dependencies:
|
|
||||||
- yt-dlp: For video downloading
|
|
||||||
- moviepy: For video editing operations
|
|
||||||
- elevenlabs: For AI narration (optional)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from backend.blocks.video.add_audio import AddAudioToVideoBlock
|
|
||||||
from backend.blocks.video.clip import VideoClipBlock
|
|
||||||
from backend.blocks.video.concat import VideoConcatBlock
|
|
||||||
from backend.blocks.video.download import VideoDownloadBlock
|
|
||||||
from backend.blocks.video.duration import MediaDurationBlock
|
|
||||||
from backend.blocks.video.loop import LoopVideoBlock
|
|
||||||
from backend.blocks.video.narration import VideoNarrationBlock
|
|
||||||
from backend.blocks.video.text_overlay import VideoTextOverlayBlock
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AddAudioToVideoBlock",
|
|
||||||
"LoopVideoBlock",
|
|
||||||
"MediaDurationBlock",
|
|
||||||
"VideoClipBlock",
|
|
||||||
"VideoConcatBlock",
|
|
||||||
"VideoDownloadBlock",
|
|
||||||
"VideoNarrationBlock",
|
|
||||||
"VideoTextOverlayBlock",
|
|
||||||
]
|
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
"""Shared utilities for video blocks."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Known operation tags added by video blocks
|
|
||||||
_VIDEO_OPS = (
|
|
||||||
r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID
|
|
||||||
_BLOCK_PREFIX_RE = re.compile(
|
|
||||||
r"^[a-zA-Z0-9_-]*"
|
|
||||||
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
|
||||||
r"[a-zA-Z0-9_-]*"
|
|
||||||
r"_" + _VIDEO_OPS + r"_"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output)
|
|
||||||
_UUID_PREFIX_RE = re.compile(
|
|
||||||
r"^[a-zA-Z0-9_-]*"
|
|
||||||
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
|
||||||
r"[a-zA-Z0-9_-]*_"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_source_name(input_path: str, max_length: int = 50) -> str:
|
|
||||||
"""Extract the original source filename by stripping block-generated prefixes.
|
|
||||||
|
|
||||||
Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate
|
|
||||||
when chaining video blocks, recovering the original human-readable name.
|
|
||||||
|
|
||||||
Safe for plain filenames (no UUID -> no stripping).
|
|
||||||
Falls back to "video" if everything is stripped.
|
|
||||||
"""
|
|
||||||
stem = Path(input_path).stem
|
|
||||||
|
|
||||||
# Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively
|
|
||||||
while _BLOCK_PREFIX_RE.match(stem):
|
|
||||||
stem = _BLOCK_PREFIX_RE.sub("", stem, count=1)
|
|
||||||
|
|
||||||
# Pass 2: strip a lone {node_exec_id}_ prefix (e.g. from download block)
|
|
||||||
if _UUID_PREFIX_RE.match(stem):
|
|
||||||
stem = _UUID_PREFIX_RE.sub("", stem, count=1)
|
|
||||||
|
|
||||||
if not stem:
|
|
||||||
return "video"
|
|
||||||
|
|
||||||
return stem[:max_length]
|
|
||||||
|
|
||||||
|
|
||||||
def get_video_codecs(output_path: str) -> tuple[str, str]:
|
|
||||||
"""Get appropriate video and audio codecs based on output file extension.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_path: Path to the output file (used to determine extension)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (video_codec, audio_codec)
|
|
||||||
|
|
||||||
Codec mappings:
|
|
||||||
- .mp4: H.264 + AAC (universal compatibility)
|
|
||||||
- .webm: VP8 + Vorbis (web streaming)
|
|
||||||
- .mkv: H.264 + AAC (container supports many codecs)
|
|
||||||
- .mov: H.264 + AAC (Apple QuickTime, widely compatible)
|
|
||||||
- .m4v: H.264 + AAC (Apple iTunes/devices)
|
|
||||||
- .avi: MPEG-4 + MP3 (legacy Windows)
|
|
||||||
"""
|
|
||||||
ext = os.path.splitext(output_path)[1].lower()
|
|
||||||
|
|
||||||
codec_map: dict[str, tuple[str, str]] = {
|
|
||||||
".mp4": ("libx264", "aac"),
|
|
||||||
".webm": ("libvpx", "libvorbis"),
|
|
||||||
".mkv": ("libx264", "aac"),
|
|
||||||
".mov": ("libx264", "aac"),
|
|
||||||
".m4v": ("libx264", "aac"),
|
|
||||||
".avi": ("mpeg4", "libmp3lame"),
|
|
||||||
}
|
|
||||||
|
|
||||||
return codec_map.get(ext, ("libx264", "aac"))
|
|
||||||
|
|
||||||
|
|
||||||
def strip_chapters_inplace(video_path: str) -> None:
|
|
||||||
"""Strip chapter metadata from a media file in-place using ffmpeg.
|
|
||||||
|
|
||||||
MoviePy 2.x crashes with IndexError when parsing files with embedded
|
|
||||||
chapter metadata (https://github.com/Zulko/moviepy/issues/2419).
|
|
||||||
This strips chapters without re-encoding.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
video_path: Absolute path to the media file to strip chapters from.
|
|
||||||
"""
|
|
||||||
base, ext = os.path.splitext(video_path)
|
|
||||||
tmp_path = base + ".tmp" + ext
|
|
||||||
try:
|
|
||||||
result = subprocess.run(
|
|
||||||
[
|
|
||||||
"ffmpeg",
|
|
||||||
"-y",
|
|
||||||
"-i",
|
|
||||||
video_path,
|
|
||||||
"-map_chapters",
|
|
||||||
"-1",
|
|
||||||
"-codec",
|
|
||||||
"copy",
|
|
||||||
tmp_path,
|
|
||||||
],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=300,
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
logger.warning(
|
|
||||||
"ffmpeg chapter strip failed (rc=%d): %s",
|
|
||||||
result.returncode,
|
|
||||||
result.stderr,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
os.replace(tmp_path, video_path)
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.warning("ffmpeg not found; skipping chapter strip")
|
|
||||||
finally:
|
|
||||||
if os.path.exists(tmp_path):
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
"""AddAudioToVideoBlock - Attach an audio track to a video file."""
|
|
||||||
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class AddAudioToVideoBlock(Block):
|
|
||||||
"""Add (attach) an audio track to an existing video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Video input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
audio_in: MediaFileType = SchemaField(
|
|
||||||
description="Audio input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
volume: float = SchemaField(
|
|
||||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
|
||||||
default=1.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Final video (with attached audio), as a path or data URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
|
||||||
description="Block to attach an audio file to a video file using moviepy.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=AddAudioToVideoBlock.Input,
|
|
||||||
output_schema=AddAudioToVideoBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
assert execution_context.node_exec_id is not None
|
|
||||||
graph_exec_id = execution_context.graph_exec_id
|
|
||||||
node_exec_id = execution_context.node_exec_id
|
|
||||||
|
|
||||||
# 1) Store the inputs locally
|
|
||||||
local_video_path = await store_media_file(
|
|
||||||
file=input_data.video_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
local_audio_path = await store_media_file(
|
|
||||||
file=input_data.audio_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
video_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
|
||||||
audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path)
|
|
||||||
|
|
||||||
# 2) Load video + audio with moviepy
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
strip_chapters_inplace(audio_abspath)
|
|
||||||
video_clip = None
|
|
||||||
audio_clip = None
|
|
||||||
final_clip = None
|
|
||||||
try:
|
|
||||||
video_clip = VideoFileClip(video_abspath)
|
|
||||||
audio_clip = AudioFileClip(audio_abspath)
|
|
||||||
# Optionally scale volume
|
|
||||||
if input_data.volume != 1.0:
|
|
||||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
|
||||||
|
|
||||||
# 3) Attach the new audio track
|
|
||||||
final_clip = video_clip.with_audio(audio_clip)
|
|
||||||
|
|
||||||
# 4) Write to output file
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
|
||||||
final_clip.write_videofile(
|
|
||||||
output_abspath, codec="libx264", audio_codec="aac"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if final_clip:
|
|
||||||
final_clip.close()
|
|
||||||
if audio_clip:
|
|
||||||
audio_clip.close()
|
|
||||||
if video_clip:
|
|
||||||
video_clip.close()
|
|
||||||
|
|
||||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
|
||||||
video_out = await store_media_file(
|
|
||||||
file=output_filename,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
@@ -1,167 +0,0 @@
|
|||||||
"""VideoClipBlock - Extract a segment from a video file."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoClipBlock(Block):
|
|
||||||
"""Extract a time segment from a video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Input video (URL, data URI, or local path)"
|
|
||||||
)
|
|
||||||
start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
|
|
||||||
end_time: float = SchemaField(description="End time in seconds", ge=0.0)
|
|
||||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
|
||||||
description="Output format", default="mp4", advanced=True
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Clipped video file (path or data URI)"
|
|
||||||
)
|
|
||||||
duration: float = SchemaField(description="Clip duration in seconds")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8f539119-e580-4d86-ad41-86fbcb22abb1",
|
|
||||||
description="Extract a time segment from a video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
test_input={
|
|
||||||
"video_in": "/tmp/test.mp4",
|
|
||||||
"start_time": 0.0,
|
|
||||||
"end_time": 10.0,
|
|
||||||
},
|
|
||||||
test_output=[("video_out", str), ("duration", float)],
|
|
||||||
test_mock={
|
|
||||||
"_clip_video": lambda *args: 10.0,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _clip_video(
|
|
||||||
self,
|
|
||||||
video_abspath: str,
|
|
||||||
output_abspath: str,
|
|
||||||
start_time: float,
|
|
||||||
end_time: float,
|
|
||||||
) -> float:
|
|
||||||
"""Extract a clip from a video. Extracted for testability."""
|
|
||||||
clip = None
|
|
||||||
subclip = None
|
|
||||||
try:
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
clip = VideoFileClip(video_abspath)
|
|
||||||
subclip = clip.subclipped(start_time, end_time)
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
subclip.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
return subclip.duration
|
|
||||||
finally:
|
|
||||||
if subclip:
|
|
||||||
subclip.close()
|
|
||||||
if clip:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# Validate time range
|
|
||||||
if input_data.end_time <= input_data.start_time:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store the input video locally
|
|
||||||
local_video_path = await self._store_input_video(
|
|
||||||
execution_context, input_data.video_in
|
|
||||||
)
|
|
||||||
video_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_video_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build output path
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(
|
|
||||||
f"{node_exec_id}_clip_{source}.{input_data.output_format}"
|
|
||||||
)
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = self._clip_video(
|
|
||||||
video_abspath,
|
|
||||||
output_abspath,
|
|
||||||
input_data.start_time,
|
|
||||||
input_data.end_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
yield "duration", duration
|
|
||||||
|
|
||||||
except BlockExecutionError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to clip video: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,227 +0,0 @@
|
|||||||
"""VideoConcatBlock - Concatenate multiple video clips into one."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from moviepy import concatenate_videoclips
|
|
||||||
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoConcatBlock(Block):
|
|
||||||
"""Merge multiple video clips into one continuous video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
videos: list[MediaFileType] = SchemaField(
|
|
||||||
description="List of video files to concatenate (in order)"
|
|
||||||
)
|
|
||||||
transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
|
|
||||||
description="Transition between clips", default="none"
|
|
||||||
)
|
|
||||||
transition_duration: int = SchemaField(
|
|
||||||
description="Transition duration in seconds",
|
|
||||||
default=1,
|
|
||||||
ge=0,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
|
||||||
description="Output format", default="mp4", advanced=True
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Concatenated video file (path or data URI)"
|
|
||||||
)
|
|
||||||
total_duration: float = SchemaField(description="Total duration in seconds")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
|
|
||||||
description="Merge multiple video clips into one continuous video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
test_input={
|
|
||||||
"videos": ["/tmp/a.mp4", "/tmp/b.mp4"],
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
("video_out", str),
|
|
||||||
("total_duration", float),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"_concat_videos": lambda *args: 20.0,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _concat_videos(
|
|
||||||
self,
|
|
||||||
video_abspaths: list[str],
|
|
||||||
output_abspath: str,
|
|
||||||
transition: str,
|
|
||||||
transition_duration: int,
|
|
||||||
) -> float:
|
|
||||||
"""Concatenate videos. Extracted for testability.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total duration of the concatenated video.
|
|
||||||
"""
|
|
||||||
clips = []
|
|
||||||
faded_clips = []
|
|
||||||
final = None
|
|
||||||
try:
|
|
||||||
# Load clips
|
|
||||||
for v in video_abspaths:
|
|
||||||
strip_chapters_inplace(v)
|
|
||||||
clips.append(VideoFileClip(v))
|
|
||||||
|
|
||||||
# Validate transition_duration against shortest clip
|
|
||||||
if transition in {"crossfade", "fade_black"} and transition_duration > 0:
|
|
||||||
min_duration = min(c.duration for c in clips)
|
|
||||||
if transition_duration >= min_duration:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=(
|
|
||||||
f"transition_duration ({transition_duration}s) must be "
|
|
||||||
f"shorter than the shortest clip ({min_duration:.2f}s)"
|
|
||||||
),
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
if transition == "crossfade":
|
|
||||||
for i, clip in enumerate(clips):
|
|
||||||
effects = []
|
|
||||||
if i > 0:
|
|
||||||
effects.append(CrossFadeIn(transition_duration))
|
|
||||||
if i < len(clips) - 1:
|
|
||||||
effects.append(CrossFadeOut(transition_duration))
|
|
||||||
if effects:
|
|
||||||
clip = clip.with_effects(effects)
|
|
||||||
faded_clips.append(clip)
|
|
||||||
final = concatenate_videoclips(
|
|
||||||
faded_clips,
|
|
||||||
method="compose",
|
|
||||||
padding=-transition_duration,
|
|
||||||
)
|
|
||||||
elif transition == "fade_black":
|
|
||||||
for clip in clips:
|
|
||||||
faded = clip.with_effects(
|
|
||||||
[FadeIn(transition_duration), FadeOut(transition_duration)]
|
|
||||||
)
|
|
||||||
faded_clips.append(faded)
|
|
||||||
final = concatenate_videoclips(faded_clips)
|
|
||||||
else:
|
|
||||||
final = concatenate_videoclips(clips)
|
|
||||||
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
final.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
|
|
||||||
return final.duration
|
|
||||||
finally:
|
|
||||||
if final:
|
|
||||||
final.close()
|
|
||||||
for clip in faded_clips:
|
|
||||||
clip.close()
|
|
||||||
for clip in clips:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# Validate minimum clips
|
|
||||||
if len(input_data.videos) < 2:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message="At least 2 videos are required for concatenation",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store all input videos locally
|
|
||||||
video_abspaths = []
|
|
||||||
for video in input_data.videos:
|
|
||||||
local_path = await self._store_input_video(execution_context, video)
|
|
||||||
video_abspaths.append(
|
|
||||||
get_exec_file_path(execution_context.graph_exec_id, local_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build output path
|
|
||||||
source = (
|
|
||||||
extract_source_name(video_abspaths[0]) if video_abspaths else "video"
|
|
||||||
)
|
|
||||||
output_filename = MediaFileType(
|
|
||||||
f"{node_exec_id}_concat_{source}.{input_data.output_format}"
|
|
||||||
)
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
total_duration = self._concat_videos(
|
|
||||||
video_abspaths,
|
|
||||||
output_abspath,
|
|
||||||
input_data.transition,
|
|
||||||
input_data.transition_duration,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
yield "total_duration", total_duration
|
|
||||||
|
|
||||||
except BlockExecutionError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to concatenate videos: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,172 +0,0 @@
|
|||||||
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import typing
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import yt_dlp
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from yt_dlp import _Params
|
|
||||||
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoDownloadBlock(Block):
|
|
||||||
"""Download video from URL using yt-dlp."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
url: str = SchemaField(
|
|
||||||
description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
|
|
||||||
placeholder="https://www.youtube.com/watch?v=...",
|
|
||||||
)
|
|
||||||
quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
|
|
||||||
description="Video quality preference", default="720p"
|
|
||||||
)
|
|
||||||
output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
|
|
||||||
description="Output video format", default="mp4", advanced=True
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_file: MediaFileType = SchemaField(
|
|
||||||
description="Downloaded video (path or data URI)"
|
|
||||||
)
|
|
||||||
duration: float = SchemaField(description="Video duration in seconds")
|
|
||||||
title: str = SchemaField(description="Video title from source")
|
|
||||||
source_url: str = SchemaField(description="Original source URL")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
|
|
||||||
description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
disabled=True, # Disable until we can sandbox yt-dlp and handle security implications
|
|
||||||
test_input={
|
|
||||||
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
|
||||||
"quality": "480p",
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
("video_file", str),
|
|
||||||
("duration", float),
|
|
||||||
("title", str),
|
|
||||||
("source_url", str),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"_download_video": lambda *args: (
|
|
||||||
"video.mp4",
|
|
||||||
212.0,
|
|
||||||
"Test Video",
|
|
||||||
),
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "video.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_format_string(self, quality: str) -> str:
|
|
||||||
formats = {
|
|
||||||
"best": "bestvideo+bestaudio/best",
|
|
||||||
"1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
|
|
||||||
"720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
|
|
||||||
"480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
|
|
||||||
"audio_only": "bestaudio/best",
|
|
||||||
}
|
|
||||||
return formats.get(quality, formats["720p"])
|
|
||||||
|
|
||||||
def _download_video(
|
|
||||||
self,
|
|
||||||
url: str,
|
|
||||||
quality: str,
|
|
||||||
output_format: str,
|
|
||||||
output_dir: str,
|
|
||||||
node_exec_id: str,
|
|
||||||
) -> tuple[str, float, str]:
|
|
||||||
"""Download video. Extracted for testability."""
|
|
||||||
output_template = os.path.join(
|
|
||||||
output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
|
|
||||||
)
|
|
||||||
|
|
||||||
ydl_opts: "_Params" = {
|
|
||||||
"format": f"{self._get_format_string(quality)}/best",
|
|
||||||
"outtmpl": output_template,
|
|
||||||
"merge_output_format": output_format,
|
|
||||||
"quiet": True,
|
|
||||||
"no_warnings": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
|
||||||
info = ydl.extract_info(url, download=True)
|
|
||||||
video_path = ydl.prepare_filename(info)
|
|
||||||
|
|
||||||
# Handle format conversion in filename
|
|
||||||
if not video_path.endswith(f".{output_format}"):
|
|
||||||
video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"
|
|
||||||
|
|
||||||
# Return just the filename, not the full path
|
|
||||||
filename = os.path.basename(video_path)
|
|
||||||
|
|
||||||
return (
|
|
||||||
filename,
|
|
||||||
info.get("duration") or 0.0,
|
|
||||||
info.get("title") or "Unknown",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Get the exec file directory
|
|
||||||
output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
filename, duration, title = self._download_video(
|
|
||||||
input_data.url,
|
|
||||||
input_data.quality,
|
|
||||||
input_data.output_format,
|
|
||||||
output_dir,
|
|
||||||
node_exec_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, MediaFileType(filename)
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_file", video_out
|
|
||||||
yield "duration", duration
|
|
||||||
yield "title", title
|
|
||||||
yield "source_url", input_data.url
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to download video: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
"""MediaDurationBlock - Get the duration of a media file."""
|
|
||||||
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import strip_chapters_inplace
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class MediaDurationBlock(Block):
|
|
||||||
"""Get the duration of a media file (video or audio)."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
media_in: MediaFileType = SchemaField(
|
|
||||||
description="Media input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
is_video: bool = SchemaField(
|
|
||||||
description="Whether the media is a video (True) or audio (False).",
|
|
||||||
default=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
duration: float = SchemaField(
|
|
||||||
description="Duration of the media file (in seconds)."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
|
||||||
description="Block to get the duration of a media file.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=MediaDurationBlock.Input,
|
|
||||||
output_schema=MediaDurationBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# 1) Store the input media locally
|
|
||||||
local_media_path = await store_media_file(
|
|
||||||
file=input_data.media_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
media_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_media_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2) Strip chapters to avoid MoviePy crash, then load the clip
|
|
||||||
strip_chapters_inplace(media_abspath)
|
|
||||||
clip = None
|
|
||||||
try:
|
|
||||||
if input_data.is_video:
|
|
||||||
clip = VideoFileClip(media_abspath)
|
|
||||||
else:
|
|
||||||
clip = AudioFileClip(media_abspath)
|
|
||||||
|
|
||||||
duration = clip.duration
|
|
||||||
finally:
|
|
||||||
if clip:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
yield "duration", duration
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from moviepy.video.fx.Loop import Loop
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class LoopVideoBlock(Block):
|
|
||||||
"""Loop (repeat) a video clip until a given duration or number of loops."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="The input video (can be a URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
duration: Optional[float] = SchemaField(
|
|
||||||
description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.",
|
|
||||||
default=None,
|
|
||||||
ge=0.0,
|
|
||||||
le=3600.0, # Max 1 hour to prevent disk exhaustion
|
|
||||||
)
|
|
||||||
n_loops: Optional[int] = SchemaField(
|
|
||||||
description="Number of times to repeat the video. Either n_loops or duration must be provided.",
|
|
||||||
default=None,
|
|
||||||
ge=1,
|
|
||||||
le=10, # Max 10 loops to prevent disk exhaustion
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Looped video returned either as a relative path or a data URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
|
||||||
description="Block to loop a video to a given duration or number of repeats.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=LoopVideoBlock.Input,
|
|
||||||
output_schema=LoopVideoBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
assert execution_context.node_exec_id is not None
|
|
||||||
graph_exec_id = execution_context.graph_exec_id
|
|
||||||
node_exec_id = execution_context.node_exec_id
|
|
||||||
|
|
||||||
# 1) Store the input video locally
|
|
||||||
local_video_path = await store_media_file(
|
|
||||||
file=input_data.video_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
|
||||||
|
|
||||||
# 2) Load the clip
|
|
||||||
strip_chapters_inplace(input_abspath)
|
|
||||||
clip = None
|
|
||||||
looped_clip = None
|
|
||||||
try:
|
|
||||||
clip = VideoFileClip(input_abspath)
|
|
||||||
|
|
||||||
# 3) Apply the loop effect
|
|
||||||
if input_data.duration:
|
|
||||||
# Loop until we reach the specified duration
|
|
||||||
looped_clip = clip.with_effects([Loop(duration=input_data.duration)])
|
|
||||||
elif input_data.n_loops:
|
|
||||||
looped_clip = clip.with_effects([Loop(n=input_data.n_loops)])
|
|
||||||
else:
|
|
||||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
|
||||||
|
|
||||||
assert isinstance(looped_clip, VideoFileClip)
|
|
||||||
|
|
||||||
# 4) Save the looped output
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
|
||||||
|
|
||||||
looped_clip = looped_clip.with_audio(clip.audio)
|
|
||||||
looped_clip.write_videofile(
|
|
||||||
output_abspath, codec="libx264", audio_codec="aac"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if looped_clip:
|
|
||||||
looped_clip.close()
|
|
||||||
if clip:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
|
||||||
video_out = await store_media_file(
|
|
||||||
file=output_filename,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
@@ -1,267 +0,0 @@
|
|||||||
"""VideoNarrationBlock - Generate AI voice narration and add to video."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from elevenlabs import ElevenLabs
|
|
||||||
from moviepy import CompositeAudioClip
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.elevenlabs._auth import (
|
|
||||||
TEST_CREDENTIALS,
|
|
||||||
TEST_CREDENTIALS_INPUT,
|
|
||||||
ElevenLabsCredentials,
|
|
||||||
ElevenLabsCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoNarrationBlock(Block):
|
|
||||||
"""Generate AI narration and add to video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: ElevenLabsCredentialsInput = CredentialsField(
|
|
||||||
description="ElevenLabs API key for voice synthesis"
|
|
||||||
)
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Input video (URL, data URI, or local path)"
|
|
||||||
)
|
|
||||||
script: str = SchemaField(description="Narration script text")
|
|
||||||
voice_id: str = SchemaField(
|
|
||||||
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
|
|
||||||
)
|
|
||||||
model_id: Literal[
|
|
||||||
"eleven_multilingual_v2",
|
|
||||||
"eleven_flash_v2_5",
|
|
||||||
"eleven_turbo_v2_5",
|
|
||||||
"eleven_turbo_v2",
|
|
||||||
] = SchemaField(
|
|
||||||
description="ElevenLabs TTS model",
|
|
||||||
default="eleven_multilingual_v2",
|
|
||||||
)
|
|
||||||
mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
|
|
||||||
description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
|
|
||||||
default="ducking",
|
|
||||||
)
|
|
||||||
narration_volume: float = SchemaField(
|
|
||||||
description="Narration volume (0.0 to 2.0)",
|
|
||||||
default=1.0,
|
|
||||||
ge=0.0,
|
|
||||||
le=2.0,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
original_volume: float = SchemaField(
|
|
||||||
description="Original audio volume when mixing (0.0 to 1.0)",
|
|
||||||
default=0.3,
|
|
||||||
ge=0.0,
|
|
||||||
le=1.0,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Video with narration (path or data URI)"
|
|
||||||
)
|
|
||||||
audio_file: MediaFileType = SchemaField(
|
|
||||||
description="Generated audio file (path or data URI)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="3d036b53-859c-4b17-9826-ca340f736e0e",
|
|
||||||
description="Generate AI narration and add to video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
test_input={
|
|
||||||
"video_in": "/tmp/test.mp4",
|
|
||||||
"script": "Hello world",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[("video_out", str), ("audio_file", str)],
|
|
||||||
test_mock={
|
|
||||||
"_generate_narration_audio": lambda *args: b"mock audio content",
|
|
||||||
"_add_narration_to_video": lambda *args: None,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _generate_narration_audio(
|
|
||||||
self, api_key: str, script: str, voice_id: str, model_id: str
|
|
||||||
) -> bytes:
|
|
||||||
"""Generate narration audio via ElevenLabs API."""
|
|
||||||
client = ElevenLabs(api_key=api_key)
|
|
||||||
audio_generator = client.text_to_speech.convert(
|
|
||||||
voice_id=voice_id,
|
|
||||||
text=script,
|
|
||||||
model_id=model_id,
|
|
||||||
)
|
|
||||||
# The SDK returns a generator, collect all chunks
|
|
||||||
return b"".join(audio_generator)
|
|
||||||
|
|
||||||
def _add_narration_to_video(
|
|
||||||
self,
|
|
||||||
video_abspath: str,
|
|
||||||
audio_abspath: str,
|
|
||||||
output_abspath: str,
|
|
||||||
mix_mode: str,
|
|
||||||
narration_volume: float,
|
|
||||||
original_volume: float,
|
|
||||||
) -> None:
|
|
||||||
"""Add narration audio to video. Extracted for testability."""
|
|
||||||
video = None
|
|
||||||
final = None
|
|
||||||
narration_original = None
|
|
||||||
narration_scaled = None
|
|
||||||
original = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
video = VideoFileClip(video_abspath)
|
|
||||||
narration_original = AudioFileClip(audio_abspath)
|
|
||||||
narration_scaled = narration_original.with_volume_scaled(narration_volume)
|
|
||||||
narration = narration_scaled
|
|
||||||
|
|
||||||
if mix_mode == "replace":
|
|
||||||
final_audio = narration
|
|
||||||
elif mix_mode == "mix":
|
|
||||||
if video.audio:
|
|
||||||
original = video.audio.with_volume_scaled(original_volume)
|
|
||||||
final_audio = CompositeAudioClip([original, narration])
|
|
||||||
else:
|
|
||||||
final_audio = narration
|
|
||||||
else: # ducking - apply stronger attenuation
|
|
||||||
if video.audio:
|
|
||||||
# Ducking uses a much lower volume for original audio
|
|
||||||
ducking_volume = original_volume * 0.3
|
|
||||||
original = video.audio.with_volume_scaled(ducking_volume)
|
|
||||||
final_audio = CompositeAudioClip([original, narration])
|
|
||||||
else:
|
|
||||||
final_audio = narration
|
|
||||||
|
|
||||||
final = video.with_audio(final_audio)
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
final.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if original:
|
|
||||||
original.close()
|
|
||||||
if narration_scaled:
|
|
||||||
narration_scaled.close()
|
|
||||||
if narration_original:
|
|
||||||
narration_original.close()
|
|
||||||
if final:
|
|
||||||
final.close()
|
|
||||||
if video:
|
|
||||||
video.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: ElevenLabsCredentials,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store the input video locally
|
|
||||||
local_video_path = await self._store_input_video(
|
|
||||||
execution_context, input_data.video_in
|
|
||||||
)
|
|
||||||
video_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_video_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate narration audio via ElevenLabs
|
|
||||||
audio_content = self._generate_narration_audio(
|
|
||||||
credentials.api_key.get_secret_value(),
|
|
||||||
input_data.script,
|
|
||||||
input_data.voice_id,
|
|
||||||
input_data.model_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save audio to exec file path
|
|
||||||
audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
|
|
||||||
audio_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, audio_filename
|
|
||||||
)
|
|
||||||
os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
|
|
||||||
with open(audio_abspath, "wb") as f:
|
|
||||||
f.write(audio_content)
|
|
||||||
|
|
||||||
# Add narration to video
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
self._add_narration_to_video(
|
|
||||||
video_abspath,
|
|
||||||
audio_abspath,
|
|
||||||
output_abspath,
|
|
||||||
input_data.mix_mode,
|
|
||||||
input_data.narration_volume,
|
|
||||||
input_data.original_volume,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
audio_out = await self._store_output_video(
|
|
||||||
execution_context, audio_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
yield "audio_file", audio_out
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to add narration: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,231 +0,0 @@
|
|||||||
"""VideoTextOverlayBlock - Add text overlay to video."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from moviepy import CompositeVideoClip, TextClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoTextOverlayBlock(Block):
|
|
||||||
"""Add text overlay/caption to video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Input video (URL, data URI, or local path)"
|
|
||||||
)
|
|
||||||
text: str = SchemaField(description="Text to overlay on video")
|
|
||||||
position: Literal[
|
|
||||||
"top",
|
|
||||||
"center",
|
|
||||||
"bottom",
|
|
||||||
"top-left",
|
|
||||||
"top-right",
|
|
||||||
"bottom-left",
|
|
||||||
"bottom-right",
|
|
||||||
] = SchemaField(description="Position of text on screen", default="bottom")
|
|
||||||
start_time: float | None = SchemaField(
|
|
||||||
description="When to show text (seconds). None = entire video",
|
|
||||||
default=None,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
end_time: float | None = SchemaField(
|
|
||||||
description="When to hide text (seconds). None = until end",
|
|
||||||
default=None,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
font_size: int = SchemaField(
|
|
||||||
description="Font size", default=48, ge=12, le=200, advanced=True
|
|
||||||
)
|
|
||||||
font_color: str = SchemaField(
|
|
||||||
description="Font color (hex or name)", default="white", advanced=True
|
|
||||||
)
|
|
||||||
bg_color: str | None = SchemaField(
|
|
||||||
description="Background color behind text (None for transparent)",
|
|
||||||
default=None,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Video with text overlay (path or data URI)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8ef14de6-cc90-430a-8cfa-3a003be92454",
|
|
||||||
description="Add text overlay/caption to video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
disabled=True, # Disable until we can lockdown imagemagick security policy
|
|
||||||
test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
|
|
||||||
test_output=[("video_out", str)],
|
|
||||||
test_mock={
|
|
||||||
"_add_text_overlay": lambda *args: None,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_text_overlay(
|
|
||||||
self,
|
|
||||||
video_abspath: str,
|
|
||||||
output_abspath: str,
|
|
||||||
text: str,
|
|
||||||
position: str,
|
|
||||||
start_time: float | None,
|
|
||||||
end_time: float | None,
|
|
||||||
font_size: int,
|
|
||||||
font_color: str,
|
|
||||||
bg_color: str | None,
|
|
||||||
) -> None:
|
|
||||||
"""Add text overlay to video. Extracted for testability."""
|
|
||||||
video = None
|
|
||||||
final = None
|
|
||||||
txt_clip = None
|
|
||||||
try:
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
video = VideoFileClip(video_abspath)
|
|
||||||
|
|
||||||
txt_clip = TextClip(
|
|
||||||
text=text,
|
|
||||||
font_size=font_size,
|
|
||||||
color=font_color,
|
|
||||||
bg_color=bg_color,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Position mapping
|
|
||||||
pos_map = {
|
|
||||||
"top": ("center", "top"),
|
|
||||||
"center": ("center", "center"),
|
|
||||||
"bottom": ("center", "bottom"),
|
|
||||||
"top-left": ("left", "top"),
|
|
||||||
"top-right": ("right", "top"),
|
|
||||||
"bottom-left": ("left", "bottom"),
|
|
||||||
"bottom-right": ("right", "bottom"),
|
|
||||||
}
|
|
||||||
|
|
||||||
txt_clip = txt_clip.with_position(pos_map[position])
|
|
||||||
|
|
||||||
# Set timing
|
|
||||||
start = start_time or 0
|
|
||||||
end = end_time or video.duration
|
|
||||||
duration = max(0, end - start)
|
|
||||||
txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)
|
|
||||||
|
|
||||||
final = CompositeVideoClip([video, txt_clip])
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
final.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if txt_clip:
|
|
||||||
txt_clip.close()
|
|
||||||
if final:
|
|
||||||
final.close()
|
|
||||||
if video:
|
|
||||||
video.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# Validate time range if both are provided
|
|
||||||
if (
|
|
||||||
input_data.start_time is not None
|
|
||||||
and input_data.end_time is not None
|
|
||||||
and input_data.end_time <= input_data.start_time
|
|
||||||
):
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store the input video locally
|
|
||||||
local_video_path = await self._store_input_video(
|
|
||||||
execution_context, input_data.video_in
|
|
||||||
)
|
|
||||||
video_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_video_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build output path
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
self._add_text_overlay(
|
|
||||||
video_abspath,
|
|
||||||
output_abspath,
|
|
||||||
input_data.text,
|
|
||||||
input_data.position,
|
|
||||||
input_data.start_time,
|
|
||||||
input_data.end_time,
|
|
||||||
input_data.font_size,
|
|
||||||
input_data.font_color,
|
|
||||||
input_data.bg_color,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
|
|
||||||
except BlockExecutionError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to add text overlay: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -165,13 +165,10 @@ class TranscribeYoutubeVideoBlock(Block):
|
|||||||
credentials: WebshareProxyCredentials,
|
credentials: WebshareProxyCredentials,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
video_id = self.extract_video_id(input_data.youtube_url)
|
||||||
video_id = self.extract_video_id(input_data.youtube_url)
|
yield "video_id", video_id
|
||||||
transcript = self.get_transcript(video_id, credentials)
|
|
||||||
transcript_text = self.format_transcript(transcript=transcript)
|
|
||||||
|
|
||||||
# Only yield after all operations succeed
|
transcript = self.get_transcript(video_id, credentials)
|
||||||
yield "video_id", video_id
|
transcript_text = self.format_transcript(transcript=transcript)
|
||||||
yield "transcript", transcript_text
|
|
||||||
except Exception as e:
|
yield "transcript", transcript_text
|
||||||
yield "error", str(e)
|
|
||||||
|
|||||||
@@ -246,9 +246,7 @@ class BlockSchema(BaseModel):
|
|||||||
f"is not of type {CredentialsMetaInput.__name__}"
|
f"is not of type {CredentialsMetaInput.__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
CredentialsMetaInput.validate_credentials_field_schema(
|
credentials_fields[field_name].validate_credentials_field_schema(cls)
|
||||||
cls.get_field_schema(field_name), field_name
|
|
||||||
)
|
|
||||||
|
|
||||||
elif field_name in credentials_fields:
|
elif field_name in credentials_fields:
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
|
|||||||
@@ -36,14 +36,12 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
|||||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||||
from backend.blocks.video.narration import VideoNarrationBlock
|
|
||||||
from backend.data.block import Block, BlockCost, BlockCostType
|
from backend.data.block import Block, BlockCost, BlockCostType
|
||||||
from backend.integrations.credentials_store import (
|
from backend.integrations.credentials_store import (
|
||||||
aiml_api_credentials,
|
aiml_api_credentials,
|
||||||
anthropic_credentials,
|
anthropic_credentials,
|
||||||
apollo_credentials,
|
apollo_credentials,
|
||||||
did_credentials,
|
did_credentials,
|
||||||
elevenlabs_credentials,
|
|
||||||
enrichlayer_credentials,
|
enrichlayer_credentials,
|
||||||
groq_credentials,
|
groq_credentials,
|
||||||
ideogram_credentials,
|
ideogram_credentials,
|
||||||
@@ -80,7 +78,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
|||||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||||
LlmModel.CLAUDE_4_OPUS: 21,
|
LlmModel.CLAUDE_4_OPUS: 21,
|
||||||
LlmModel.CLAUDE_4_SONNET: 5,
|
LlmModel.CLAUDE_4_SONNET: 5,
|
||||||
LlmModel.CLAUDE_4_6_OPUS: 14,
|
|
||||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||||
@@ -642,16 +639,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
VideoNarrationBlock: [
|
|
||||||
BlockCost(
|
|
||||||
cost_amount=5, # ElevenLabs TTS cost
|
|
||||||
cost_filter={
|
|
||||||
"credentials": {
|
|
||||||
"id": elevenlabs_credentials.id,
|
|
||||||
"provider": elevenlabs_credentials.provider,
|
|
||||||
"type": elevenlabs_credentials.type,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,16 +134,6 @@ async def test_block_credit_reset(server: SpinTestServer):
|
|||||||
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
||||||
user_credit.time_now = lambda: month1
|
user_credit.time_now = lambda: month1
|
||||||
|
|
||||||
# IMPORTANT: Set updatedAt to December of previous year to ensure it's
|
|
||||||
# in a different month than month1 (January). This fixes a timing bug
|
|
||||||
# where if the test runs in early February, 35 days ago would be January,
|
|
||||||
# matching the mocked month1 and preventing the refill from triggering.
|
|
||||||
dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
|
|
||||||
await UserBalance.prisma().update(
|
|
||||||
where={"userId": DEFAULT_USER_ID},
|
|
||||||
data={"updatedAt": dec_previous_year},
|
|
||||||
)
|
|
||||||
|
|
||||||
# First call in month 1 should trigger refill
|
# First call in month 1 should trigger refill
|
||||||
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
assert balance == REFILL_VALUE # Should get 1000 credits
|
assert balance == REFILL_VALUE # Should get 1000 credits
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import queue
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from multiprocessing import Manager
|
||||||
|
from queue import Empty
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Annotated,
|
Annotated,
|
||||||
@@ -1199,16 +1200,12 @@ class NodeExecutionEntry(BaseModel):
|
|||||||
|
|
||||||
class ExecutionQueue(Generic[T]):
|
class ExecutionQueue(Generic[T]):
|
||||||
"""
|
"""
|
||||||
Thread-safe queue for managing node execution within a single graph execution.
|
Queue for managing the execution of agents.
|
||||||
|
This will be shared between different processes
|
||||||
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
|
|
||||||
threads within the same process. If migrating back to ProcessPoolExecutor,
|
|
||||||
replace with multiprocessing.Manager().Queue() for cross-process safety.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Thread-safe queue (not multiprocessing) — see class docstring
|
self.queue = Manager().Queue()
|
||||||
self.queue: queue.Queue[T] = queue.Queue()
|
|
||||||
|
|
||||||
def add(self, execution: T) -> T:
|
def add(self, execution: T) -> T:
|
||||||
self.queue.put(execution)
|
self.queue.put(execution)
|
||||||
@@ -1223,7 +1220,7 @@ class ExecutionQueue(Generic[T]):
|
|||||||
def get_or_none(self) -> T | None:
|
def get_or_none(self) -> T | None:
|
||||||
try:
|
try:
|
||||||
return self.queue.get_nowait()
|
return self.queue.get_nowait()
|
||||||
except queue.Empty:
|
except Empty:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
"""Tests for ExecutionQueue thread-safety."""
|
|
||||||
|
|
||||||
import queue
|
|
||||||
import threading
|
|
||||||
|
|
||||||
from backend.data.execution import ExecutionQueue
|
|
||||||
|
|
||||||
|
|
||||||
def test_execution_queue_uses_stdlib_queue():
|
|
||||||
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
|
|
||||||
q = ExecutionQueue()
|
|
||||||
assert isinstance(q.queue, queue.Queue)
|
|
||||||
|
|
||||||
|
|
||||||
def test_basic_operations():
|
|
||||||
"""Test add, get, empty, and get_or_none."""
|
|
||||||
q = ExecutionQueue()
|
|
||||||
|
|
||||||
assert q.empty() is True
|
|
||||||
assert q.get_or_none() is None
|
|
||||||
|
|
||||||
result = q.add("item1")
|
|
||||||
assert result == "item1"
|
|
||||||
assert q.empty() is False
|
|
||||||
|
|
||||||
item = q.get()
|
|
||||||
assert item == "item1"
|
|
||||||
assert q.empty() is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_thread_safety():
|
|
||||||
"""Test concurrent access from multiple threads."""
|
|
||||||
q = ExecutionQueue()
|
|
||||||
results = []
|
|
||||||
num_items = 100
|
|
||||||
|
|
||||||
def producer():
|
|
||||||
for i in range(num_items):
|
|
||||||
q.add(f"item_{i}")
|
|
||||||
|
|
||||||
def consumer():
|
|
||||||
count = 0
|
|
||||||
while count < num_items:
|
|
||||||
item = q.get_or_none()
|
|
||||||
if item is not None:
|
|
||||||
results.append(item)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
producer_thread = threading.Thread(target=producer)
|
|
||||||
consumer_thread = threading.Thread(target=consumer)
|
|
||||||
|
|
||||||
producer_thread.start()
|
|
||||||
consumer_thread.start()
|
|
||||||
|
|
||||||
producer_thread.join(timeout=5)
|
|
||||||
consumer_thread.join(timeout=5)
|
|
||||||
|
|
||||||
assert len(results) == num_items
|
|
||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Self, cast
|
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
|
||||||
|
|
||||||
from prisma.enums import SubmissionStatus
|
from prisma.enums import SubmissionStatus
|
||||||
from prisma.models import (
|
from prisma.models import (
|
||||||
@@ -20,7 +20,7 @@ from prisma.types import (
|
|||||||
AgentNodeLinkCreateInput,
|
AgentNodeLinkCreateInput,
|
||||||
StoreListingVersionWhereInput,
|
StoreListingVersionWhereInput,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, BeforeValidator, Field
|
from pydantic import BaseModel, BeforeValidator, Field, create_model
|
||||||
from pydantic.fields import computed_field
|
from pydantic.fields import computed_field
|
||||||
|
|
||||||
from backend.blocks.agent import AgentExecutorBlock
|
from backend.blocks.agent import AgentExecutorBlock
|
||||||
@@ -30,6 +30,7 @@ from backend.data.db import prisma as db
|
|||||||
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
|
CredentialsField,
|
||||||
CredentialsFieldInfo,
|
CredentialsFieldInfo,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
is_credentials_field_name,
|
is_credentials_field_name,
|
||||||
@@ -44,6 +45,7 @@ from .block import (
|
|||||||
AnyBlockSchema,
|
AnyBlockSchema,
|
||||||
Block,
|
Block,
|
||||||
BlockInput,
|
BlockInput,
|
||||||
|
BlockSchema,
|
||||||
BlockType,
|
BlockType,
|
||||||
EmptySchema,
|
EmptySchema,
|
||||||
get_block,
|
get_block,
|
||||||
@@ -111,12 +113,10 @@ class Link(BaseDbModel):
|
|||||||
|
|
||||||
class Node(BaseDbModel):
|
class Node(BaseDbModel):
|
||||||
block_id: str
|
block_id: str
|
||||||
input_default: BlockInput = Field( # dict[input_name, default_value]
|
input_default: BlockInput = {} # dict[input_name, default_value]
|
||||||
default_factory=dict
|
metadata: dict[str, Any] = {}
|
||||||
)
|
input_links: list[Link] = []
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
output_links: list[Link] = []
|
||||||
input_links: list[Link] = Field(default_factory=list)
|
|
||||||
output_links: list[Link] = Field(default_factory=list)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def credentials_optional(self) -> bool:
|
def credentials_optional(self) -> bool:
|
||||||
@@ -221,33 +221,18 @@ class NodeModel(Node):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class GraphBaseMeta(BaseDbModel):
|
class BaseGraph(BaseDbModel):
|
||||||
"""
|
|
||||||
Shared base for `GraphMeta` and `BaseGraph`, with core graph metadata fields.
|
|
||||||
"""
|
|
||||||
|
|
||||||
version: int = 1
|
version: int = 1
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
instructions: str | None = None
|
instructions: str | None = None
|
||||||
recommended_schedule_cron: str | None = None
|
recommended_schedule_cron: str | None = None
|
||||||
|
nodes: list[Node] = []
|
||||||
|
links: list[Link] = []
|
||||||
forked_from_id: str | None = None
|
forked_from_id: str | None = None
|
||||||
forked_from_version: int | None = None
|
forked_from_version: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseGraph(GraphBaseMeta):
|
|
||||||
"""
|
|
||||||
Graph with nodes, links, and computed I/O schema fields.
|
|
||||||
|
|
||||||
Used to represent sub-graphs within a `Graph`. Contains the full graph
|
|
||||||
structure including nodes and links, plus computed fields for schemas
|
|
||||||
and trigger info. Does NOT include user_id or created_at (see GraphModel).
|
|
||||||
"""
|
|
||||||
|
|
||||||
nodes: list[Node] = Field(default_factory=list)
|
|
||||||
links: list[Link] = Field(default_factory=list)
|
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def input_schema(self) -> dict[str, Any]:
|
def input_schema(self) -> dict[str, Any]:
|
||||||
@@ -376,79 +361,44 @@ class GraphTriggerInfo(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Graph(BaseGraph):
|
class Graph(BaseGraph):
|
||||||
"""Creatable graph model used in API create/update endpoints."""
|
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
||||||
|
|
||||||
sub_graphs: list[BaseGraph] = Field(default_factory=list) # Flattened sub-graphs
|
|
||||||
|
|
||||||
|
|
||||||
class GraphMeta(GraphBaseMeta):
|
|
||||||
"""
|
|
||||||
Lightweight graph metadata model representing an existing graph from the database,
|
|
||||||
for use in listings and summaries.
|
|
||||||
|
|
||||||
Lacks `GraphModel`'s nodes, links, and expensive computed fields.
|
|
||||||
Use for list endpoints where full graph data is not needed and performance matters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: str # type: ignore
|
|
||||||
version: int # type: ignore
|
|
||||||
user_id: str
|
|
||||||
created_at: datetime
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_db(cls, graph: "AgentGraph") -> Self:
|
|
||||||
return cls(
|
|
||||||
id=graph.id,
|
|
||||||
version=graph.version,
|
|
||||||
is_active=graph.isActive,
|
|
||||||
name=graph.name or "",
|
|
||||||
description=graph.description or "",
|
|
||||||
instructions=graph.instructions,
|
|
||||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
|
||||||
forked_from_id=graph.forkedFromId,
|
|
||||||
forked_from_version=graph.forkedFromVersion,
|
|
||||||
user_id=graph.userId,
|
|
||||||
created_at=graph.createdAt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphModel(Graph, GraphMeta):
|
|
||||||
"""
|
|
||||||
Full graph model representing an existing graph from the database.
|
|
||||||
|
|
||||||
This is the primary model for working with persisted graphs. Includes all
|
|
||||||
graph data (nodes, links, sub_graphs) plus user ownership and timestamps.
|
|
||||||
Provides computed fields (input_schema, output_schema, etc.) used during
|
|
||||||
set-up (frontend) and execution (backend).
|
|
||||||
|
|
||||||
Inherits from:
|
|
||||||
- `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas
|
|
||||||
- `GraphMeta`: provides user_id, created_at for database records
|
|
||||||
"""
|
|
||||||
|
|
||||||
nodes: list[NodeModel] = Field(default_factory=list) # type: ignore
|
|
||||||
|
|
||||||
@property
|
|
||||||
def starting_nodes(self) -> list[NodeModel]:
|
|
||||||
outbound_nodes = {link.sink_id for link in self.links}
|
|
||||||
input_nodes = {
|
|
||||||
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
|
||||||
}
|
|
||||||
return [
|
|
||||||
node
|
|
||||||
for node in self.nodes
|
|
||||||
if node.id not in outbound_nodes or node.id in input_nodes
|
|
||||||
]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
|
||||||
return cast(NodeModel, super().webhook_input_node)
|
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def credentials_input_schema(self) -> dict[str, Any]:
|
def credentials_input_schema(self) -> dict[str, Any]:
|
||||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
schema = self._credentials_input_schema.jsonschema()
|
||||||
|
|
||||||
|
# Determine which credential fields are required based on credentials_optional metadata
|
||||||
|
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||||
|
required_fields = []
|
||||||
|
|
||||||
|
# Build a map of node_id -> node for quick lookup
|
||||||
|
all_nodes = {node.id: node for node in self.nodes}
|
||||||
|
for sub_graph in self.sub_graphs:
|
||||||
|
for node in sub_graph.nodes:
|
||||||
|
all_nodes[node.id] = node
|
||||||
|
|
||||||
|
for field_key, (
|
||||||
|
_field_info,
|
||||||
|
node_field_pairs,
|
||||||
|
) in graph_credentials_inputs.items():
|
||||||
|
# A field is required if ANY node using it has credentials_optional=False
|
||||||
|
is_required = False
|
||||||
|
for node_id, _field_name in node_field_pairs:
|
||||||
|
node = all_nodes.get(node_id)
|
||||||
|
if node and not node.credentials_optional:
|
||||||
|
is_required = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if is_required:
|
||||||
|
required_fields.append(field_key)
|
||||||
|
|
||||||
|
schema["required"] = required_fields
|
||||||
|
return schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||||
|
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||||
f"{graph_credentials_inputs}"
|
f"{graph_credentials_inputs}"
|
||||||
@@ -456,8 +406,8 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
|
|
||||||
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
||||||
graph_cred_fields = list(graph_credentials_inputs.values())
|
graph_cred_fields = list(graph_credentials_inputs.values())
|
||||||
for i, (field, keys, _) in enumerate(graph_cred_fields):
|
for i, (field, keys) in enumerate(graph_cred_fields):
|
||||||
for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]:
|
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||||
if field.provider != other_field.provider:
|
if field.provider != other_field.provider:
|
||||||
continue
|
continue
|
||||||
if ProviderName.HTTP in field.provider:
|
if ProviderName.HTTP in field.provider:
|
||||||
@@ -473,78 +423,31 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
f"keys: {keys} <> {other_keys}."
|
f"keys: {keys} <> {other_keys}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build JSON schema directly to avoid expensive create_model + validation overhead
|
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
|
||||||
properties = {}
|
agg_field_key: (
|
||||||
required_fields = []
|
CredentialsMetaInput[
|
||||||
|
Literal[tuple(field_info.provider)], # type: ignore
|
||||||
for agg_field_key, (
|
Literal[tuple(field_info.supported_types)], # type: ignore
|
||||||
field_info,
|
],
|
||||||
_,
|
CredentialsField(
|
||||||
is_required,
|
required_scopes=set(field_info.required_scopes or []),
|
||||||
) in graph_credentials_inputs.items():
|
discriminator=field_info.discriminator,
|
||||||
providers = list(field_info.provider)
|
discriminator_mapping=field_info.discriminator_mapping,
|
||||||
cred_types = list(field_info.supported_types)
|
discriminator_values=field_info.discriminator_values,
|
||||||
|
),
|
||||||
field_schema: dict[str, Any] = {
|
|
||||||
"credentials_provider": providers,
|
|
||||||
"credentials_types": cred_types,
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"id": {"title": "Id", "type": "string"},
|
|
||||||
"title": {
|
|
||||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
|
||||||
"default": None,
|
|
||||||
"title": "Title",
|
|
||||||
},
|
|
||||||
"provider": {
|
|
||||||
"title": "Provider",
|
|
||||||
"type": "string",
|
|
||||||
**(
|
|
||||||
{"enum": providers}
|
|
||||||
if len(providers) > 1
|
|
||||||
else {"const": providers[0]}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"title": "Type",
|
|
||||||
"type": "string",
|
|
||||||
**(
|
|
||||||
{"enum": cred_types}
|
|
||||||
if len(cred_types) > 1
|
|
||||||
else {"const": cred_types[0]}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["id", "provider", "type"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add other (optional) field info items
|
|
||||||
field_schema.update(
|
|
||||||
field_info.model_dump(
|
|
||||||
by_alias=True,
|
|
||||||
exclude_defaults=True,
|
|
||||||
exclude={"provider", "supported_types"}, # already included above
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||||
# Ensure field schema is well-formed
|
|
||||||
CredentialsMetaInput.validate_credentials_field_schema(
|
|
||||||
field_schema, agg_field_key
|
|
||||||
)
|
|
||||||
|
|
||||||
properties[agg_field_key] = field_schema
|
|
||||||
if is_required:
|
|
||||||
required_fields.append(agg_field_key)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": properties,
|
|
||||||
"required": required_fields,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return create_model(
|
||||||
|
self.name.replace(" ", "") + "CredentialsInputSchema",
|
||||||
|
__base__=BlockSchema,
|
||||||
|
**fields, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
def aggregate_credentials_inputs(
|
def aggregate_credentials_inputs(
|
||||||
self,
|
self,
|
||||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
dict[aggregated_field_key, tuple(
|
dict[aggregated_field_key, tuple(
|
||||||
@@ -552,19 +455,13 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
(now includes discriminator_values from matching nodes)
|
(now includes discriminator_values from matching nodes)
|
||||||
set[(node_id, field_name)]: Node credentials fields that are
|
set[(node_id, field_name)]: Node credentials fields that are
|
||||||
compatible with this aggregated field spec
|
compatible with this aggregated field spec
|
||||||
bool: True if the field is required (any node has credentials_optional=False)
|
|
||||||
)]
|
)]
|
||||||
"""
|
"""
|
||||||
# First collect all credential field data with input defaults
|
# First collect all credential field data with input defaults
|
||||||
# Track (field_info, (node_id, field_name), is_required) for each credential field
|
node_credential_data = []
|
||||||
node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
|
|
||||||
node_required_map: dict[str, bool] = {} # node_id -> is_required
|
|
||||||
|
|
||||||
for graph in [self] + self.sub_graphs:
|
for graph in [self] + self.sub_graphs:
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
# Track if this node requires credentials (credentials_optional=False means required)
|
|
||||||
node_required_map[node.id] = not node.credentials_optional
|
|
||||||
|
|
||||||
for (
|
for (
|
||||||
field_name,
|
field_name,
|
||||||
field_info,
|
field_info,
|
||||||
@@ -588,21 +485,37 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Combine credential field info (this will merge discriminator_values automatically)
|
# Combine credential field info (this will merge discriminator_values automatically)
|
||||||
combined = CredentialsFieldInfo.combine(*node_credential_data)
|
return CredentialsFieldInfo.combine(*node_credential_data)
|
||||||
|
|
||||||
# Add is_required flag to each aggregated field
|
|
||||||
# A field is required if ANY node using it has credentials_optional=False
|
class GraphModel(Graph):
|
||||||
return {
|
user_id: str
|
||||||
key: (
|
nodes: list[NodeModel] = [] # type: ignore
|
||||||
field_info,
|
|
||||||
node_field_pairs,
|
created_at: datetime
|
||||||
any(
|
|
||||||
node_required_map.get(node_id, True)
|
@property
|
||||||
for node_id, _ in node_field_pairs
|
def starting_nodes(self) -> list[NodeModel]:
|
||||||
),
|
outbound_nodes = {link.sink_id for link in self.links}
|
||||||
)
|
input_nodes = {
|
||||||
for key, (field_info, node_field_pairs) in combined.items()
|
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
||||||
}
|
}
|
||||||
|
return [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.id not in outbound_nodes or node.id in input_nodes
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
||||||
|
return cast(NodeModel, super().webhook_input_node)
|
||||||
|
|
||||||
|
def meta(self) -> "GraphMeta":
|
||||||
|
"""
|
||||||
|
Returns a GraphMeta object with metadata about the graph.
|
||||||
|
This is used to return metadata about the graph without exposing nodes and links.
|
||||||
|
"""
|
||||||
|
return GraphMeta.from_graph(self)
|
||||||
|
|
||||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||||
"""
|
"""
|
||||||
@@ -743,11 +656,6 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
# For invalid blocks, we still raise immediately as this is a structural issue
|
# For invalid blocks, we still raise immediately as this is a structural issue
|
||||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||||
|
|
||||||
if block.disabled:
|
|
||||||
raise ValueError(
|
|
||||||
f"Block {node.block_id} is disabled and cannot be used in graphs"
|
|
||||||
)
|
|
||||||
|
|
||||||
node_input_mask = (
|
node_input_mask = (
|
||||||
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
|
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
|
||||||
)
|
)
|
||||||
@@ -891,14 +799,13 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
if is_static_output_block(link.source_id):
|
if is_static_output_block(link.source_id):
|
||||||
link.is_static = True # Each value block output should be static.
|
link.is_static = True # Each value block output should be static.
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def from_db( # type: ignore[reportIncompatibleMethodOverride]
|
def from_db(
|
||||||
cls,
|
|
||||||
graph: AgentGraph,
|
graph: AgentGraph,
|
||||||
for_export: bool = False,
|
for_export: bool = False,
|
||||||
sub_graphs: list[AgentGraph] | None = None,
|
sub_graphs: list[AgentGraph] | None = None,
|
||||||
) -> Self:
|
) -> "GraphModel":
|
||||||
return cls(
|
return GraphModel(
|
||||||
id=graph.id,
|
id=graph.id,
|
||||||
user_id=graph.userId if not for_export else "",
|
user_id=graph.userId if not for_export else "",
|
||||||
version=graph.version,
|
version=graph.version,
|
||||||
@@ -924,28 +831,17 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def hide_nodes(self) -> "GraphModelWithoutNodes":
|
|
||||||
"""
|
|
||||||
Returns a copy of the `GraphModel` with nodes, links, and sub-graphs hidden
|
|
||||||
(excluded from serialization). They are still present in the model instance
|
|
||||||
so all computed fields (e.g. `credentials_input_schema`) still work.
|
|
||||||
"""
|
|
||||||
return GraphModelWithoutNodes.model_validate(self, from_attributes=True)
|
|
||||||
|
|
||||||
|
class GraphMeta(Graph):
|
||||||
|
user_id: str
|
||||||
|
|
||||||
class GraphModelWithoutNodes(GraphModel):
|
# Easy work-around to prevent exposing nodes and links in the API response
|
||||||
"""
|
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
|
||||||
GraphModel variant that excludes nodes, links, and sub-graphs from serialization.
|
links: list[Link] = Field(default=[], exclude=True)
|
||||||
|
|
||||||
Used in contexts like the store where exposing internal graph structure
|
@staticmethod
|
||||||
is not desired. Inherits all computed fields from GraphModel but marks
|
def from_graph(graph: GraphModel) -> "GraphMeta":
|
||||||
nodes and links as excluded from JSON output.
|
return GraphMeta(**graph.model_dump())
|
||||||
"""
|
|
||||||
|
|
||||||
nodes: list[NodeModel] = Field(default_factory=list, exclude=True)
|
|
||||||
links: list[Link] = Field(default_factory=list, exclude=True)
|
|
||||||
|
|
||||||
sub_graphs: list[BaseGraph] = Field(default_factory=list, exclude=True)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphsPaginated(BaseModel):
|
class GraphsPaginated(BaseModel):
|
||||||
@@ -1016,11 +912,21 @@ async def list_graphs_paginated(
|
|||||||
where=where_clause,
|
where=where_clause,
|
||||||
distinct=["id"],
|
distinct=["id"],
|
||||||
order={"version": "desc"},
|
order={"version": "desc"},
|
||||||
|
include=AGENT_GRAPH_INCLUDE,
|
||||||
skip=offset,
|
skip=offset,
|
||||||
take=page_size,
|
take=page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_models = [GraphMeta.from_db(graph) for graph in graphs]
|
graph_models: list[GraphMeta] = []
|
||||||
|
for graph in graphs:
|
||||||
|
try:
|
||||||
|
graph_meta = GraphModel.from_db(graph).meta()
|
||||||
|
# Trigger serialization to validate that the graph is well formed
|
||||||
|
graph_meta.model_dump()
|
||||||
|
graph_models.append(graph_meta)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing graph {graph.id}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
return GraphsPaginated(
|
return GraphsPaginated(
|
||||||
graphs=graph_models,
|
graphs=graph_models,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
get_args,
|
get_args,
|
||||||
)
|
)
|
||||||
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from prisma.enums import CreditTransactionType, OnboardingStep
|
from prisma.enums import CreditTransactionType, OnboardingStep
|
||||||
@@ -41,7 +42,6 @@ from typing_extensions import TypedDict
|
|||||||
|
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.json import loads as json_loads
|
from backend.util.json import loads as json_loads
|
||||||
from backend.util.request import parse_url
|
|
||||||
from backend.util.settings import Secrets
|
from backend.util.settings import Secrets
|
||||||
|
|
||||||
# Type alias for any provider name (including custom ones)
|
# Type alias for any provider name (including custom ones)
|
||||||
@@ -163,6 +163,7 @@ class User(BaseModel):
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from prisma.models import User as PrismaUser
|
from prisma.models import User as PrismaUser
|
||||||
|
|
||||||
|
from backend.data.block import BlockSchema
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -396,25 +397,19 @@ class HostScopedCredentials(_BaseCredentials):
|
|||||||
def matches_url(self, url: str) -> bool:
|
def matches_url(self, url: str) -> bool:
|
||||||
"""Check if this credential should be applied to the given URL."""
|
"""Check if this credential should be applied to the given URL."""
|
||||||
|
|
||||||
request_host, request_port = _extract_host_from_url(url)
|
parsed_url = urlparse(url)
|
||||||
cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
|
# Extract hostname without port
|
||||||
|
request_host = parsed_url.hostname
|
||||||
if not request_host:
|
if not request_host:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# If a port is specified in credential host, the request host port must match
|
# Simple host matching - exact match or wildcard subdomain match
|
||||||
if cred_scope_port is not None and request_port != cred_scope_port:
|
if self.host == request_host:
|
||||||
return False
|
|
||||||
# Non-standard ports are only allowed if explicitly specified in credential host
|
|
||||||
elif cred_scope_port is None and request_port not in (80, 443, None):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Simple host matching
|
|
||||||
if cred_scope_host == request_host:
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
||||||
if cred_scope_host.startswith("*."):
|
if self.host.startswith("*."):
|
||||||
domain = cred_scope_host[2:] # Remove "*."
|
domain = self.host[2:] # Remove "*."
|
||||||
return request_host.endswith(f".{domain}") or request_host == domain
|
return request_host.endswith(f".{domain}") or request_host == domain
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -507,13 +502,15 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
||||||
return get_args(cls.model_fields["type"].annotation)
|
return get_args(cls.model_fields["type"].annotation)
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def validate_credentials_field_schema(
|
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
|
||||||
field_schema: dict[str, Any], field_name: str
|
|
||||||
):
|
|
||||||
"""Validates the schema of a credentials input field"""
|
"""Validates the schema of a credentials input field"""
|
||||||
|
field_name = next(
|
||||||
|
name for name, type in model.get_credentials_fields().items() if type is cls
|
||||||
|
)
|
||||||
|
field_schema = model.jsonschema()["properties"][field_name]
|
||||||
try:
|
try:
|
||||||
field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if "Field required [type=missing" not in str(e):
|
if "Field required [type=missing" not in str(e):
|
||||||
raise
|
raise
|
||||||
@@ -523,11 +520,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
f"{field_schema}"
|
f"{field_schema}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
providers = field_info.provider
|
providers = cls.allowed_providers()
|
||||||
if (
|
if (
|
||||||
providers is not None
|
providers is not None
|
||||||
and len(providers) > 1
|
and len(providers) > 1
|
||||||
and not field_info.discriminator
|
and not schema_extra.discriminator
|
||||||
):
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Multi-provider CredentialsField '{field_name}' "
|
f"Multi-provider CredentialsField '{field_name}' "
|
||||||
@@ -554,13 +551,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _extract_host_from_url(url: str) -> tuple[str, int | None]:
|
def _extract_host_from_url(url: str) -> str:
|
||||||
"""Extract host and port from URL for grouping host-scoped credentials."""
|
"""Extract host from URL for grouping host-scoped credentials."""
|
||||||
try:
|
try:
|
||||||
parsed = parse_url(url)
|
parsed = urlparse(url)
|
||||||
return parsed.hostname or url, parsed.port
|
return parsed.hostname or url
|
||||||
except Exception:
|
except Exception:
|
||||||
return "", None
|
return ""
|
||||||
|
|
||||||
|
|
||||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||||
@@ -609,7 +606,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
providers = frozenset(
|
providers = frozenset(
|
||||||
[cast(CP, "http")]
|
[cast(CP, "http")]
|
||||||
+ [
|
+ [
|
||||||
cast(CP, parse_url(str(value)).netloc)
|
cast(CP, _extract_host_from_url(str(value)))
|
||||||
for value in field.discriminator_values
|
for value in field.discriminator_values
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -79,23 +79,10 @@ class TestHostScopedCredentials:
|
|||||||
headers={"Authorization": SecretStr("Bearer token")},
|
headers={"Authorization": SecretStr("Bearer token")},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Non-standard ports require explicit port in credential host
|
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||||
assert not creds.matches_url("http://localhost:8080/api/v1")
|
|
||||||
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
||||||
assert creds.matches_url("http://localhost/simple")
|
assert creds.matches_url("http://localhost/simple")
|
||||||
|
|
||||||
def test_matches_url_with_explicit_port(self):
|
|
||||||
"""Test URL matching with explicit port in credential host."""
|
|
||||||
creds = HostScopedCredentials(
|
|
||||||
provider="custom",
|
|
||||||
host="localhost:8080",
|
|
||||||
headers={"Authorization": SecretStr("Bearer token")},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
|
||||||
assert not creds.matches_url("http://localhost:3000/api/v1")
|
|
||||||
assert not creds.matches_url("http://localhost/simple")
|
|
||||||
|
|
||||||
def test_empty_headers_dict(self):
|
def test_empty_headers_dict(self):
|
||||||
"""Test HostScopedCredentials with empty headers."""
|
"""Test HostScopedCredentials with empty headers."""
|
||||||
creds = HostScopedCredentials(
|
creds = HostScopedCredentials(
|
||||||
@@ -141,20 +128,8 @@ class TestHostScopedCredentials:
|
|||||||
("*.example.com", "https://sub.api.example.com/test", True),
|
("*.example.com", "https://sub.api.example.com/test", True),
|
||||||
("*.example.com", "https://example.com/test", True),
|
("*.example.com", "https://example.com/test", True),
|
||||||
("*.example.com", "https://example.org/test", False),
|
("*.example.com", "https://example.org/test", False),
|
||||||
# Non-standard ports require explicit port in credential host
|
("localhost", "http://localhost:3000/test", True),
|
||||||
("localhost", "http://localhost:3000/test", False),
|
|
||||||
("localhost:3000", "http://localhost:3000/test", True),
|
|
||||||
("localhost", "http://127.0.0.1:3000/test", False),
|
("localhost", "http://127.0.0.1:3000/test", False),
|
||||||
# IPv6 addresses (frontend stores with brackets via URL.hostname)
|
|
||||||
("[::1]", "http://[::1]/test", True),
|
|
||||||
("[::1]", "http://[::1]:80/test", True),
|
|
||||||
("[::1]", "https://[::1]:443/test", True),
|
|
||||||
("[::1]", "http://[::1]:8080/test", False), # Non-standard port
|
|
||||||
("[::1]:8080", "http://[::1]:8080/test", True),
|
|
||||||
("[::1]:8080", "http://[::1]:9090/test", False),
|
|
||||||
("[2001:db8::1]", "http://[2001:db8::1]/path", True),
|
|
||||||
("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
|
|
||||||
("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -226,10 +225,6 @@ class SyncRabbitMQ(RabbitMQBase):
|
|||||||
class AsyncRabbitMQ(RabbitMQBase):
|
class AsyncRabbitMQ(RabbitMQBase):
|
||||||
"""Asynchronous RabbitMQ client"""
|
"""Asynchronous RabbitMQ client"""
|
||||||
|
|
||||||
def __init__(self, config: RabbitMQConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
self._reconnect_lock: asyncio.Lock | None = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return bool(self._connection and not self._connection.is_closed)
|
return bool(self._connection and not self._connection.is_closed)
|
||||||
@@ -240,17 +235,7 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
|
|
||||||
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
if self.is_connected and self._channel and not self._channel.is_closed:
|
if self.is_connected:
|
||||||
return
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.is_connected
|
|
||||||
and self._connection
|
|
||||||
and (self._channel is None or self._channel.is_closed)
|
|
||||||
):
|
|
||||||
self._channel = await self._connection.channel()
|
|
||||||
await self._channel.set_qos(prefetch_count=1)
|
|
||||||
await self.declare_infrastructure()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self._connection = await aio_pika.connect_robust(
|
self._connection = await aio_pika.connect_robust(
|
||||||
@@ -306,46 +291,24 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
exchange, routing_key=queue.routing_key or queue.name
|
exchange, routing_key=queue.routing_key or queue.name
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@func_retry
|
||||||
def _lock(self) -> asyncio.Lock:
|
async def publish_message(
|
||||||
if self._reconnect_lock is None:
|
|
||||||
self._reconnect_lock = asyncio.Lock()
|
|
||||||
return self._reconnect_lock
|
|
||||||
|
|
||||||
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
|
|
||||||
"""Get a valid channel, reconnecting if the current one is stale.
|
|
||||||
|
|
||||||
Uses a lock to prevent concurrent reconnection attempts from racing.
|
|
||||||
"""
|
|
||||||
if self.is_ready:
|
|
||||||
return self._channel # type: ignore # is_ready guarantees non-None
|
|
||||||
|
|
||||||
async with self._lock:
|
|
||||||
# Double-check after acquiring lock
|
|
||||||
if self.is_ready:
|
|
||||||
return self._channel # type: ignore
|
|
||||||
|
|
||||||
self._channel = None
|
|
||||||
await self.connect()
|
|
||||||
|
|
||||||
if self._channel is None:
|
|
||||||
raise RuntimeError("Channel should be established after connect")
|
|
||||||
|
|
||||||
return self._channel
|
|
||||||
|
|
||||||
async def _publish_once(
|
|
||||||
self,
|
self,
|
||||||
routing_key: str,
|
routing_key: str,
|
||||||
message: str,
|
message: str,
|
||||||
exchange: Optional[Exchange] = None,
|
exchange: Optional[Exchange] = None,
|
||||||
persistent: bool = True,
|
persistent: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
channel = await self._ensure_channel()
|
if not self.is_ready:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
if self._channel is None:
|
||||||
|
raise RuntimeError("Channel should be established after connect")
|
||||||
|
|
||||||
if exchange:
|
if exchange:
|
||||||
exchange_obj = await channel.get_exchange(exchange.name)
|
exchange_obj = await self._channel.get_exchange(exchange.name)
|
||||||
else:
|
else:
|
||||||
exchange_obj = channel.default_exchange
|
exchange_obj = self._channel.default_exchange
|
||||||
|
|
||||||
await exchange_obj.publish(
|
await exchange_obj.publish(
|
||||||
aio_pika.Message(
|
aio_pika.Message(
|
||||||
@@ -359,23 +322,9 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
routing_key=routing_key,
|
routing_key=routing_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@func_retry
|
|
||||||
async def publish_message(
|
|
||||||
self,
|
|
||||||
routing_key: str,
|
|
||||||
message: str,
|
|
||||||
exchange: Optional[Exchange] = None,
|
|
||||||
persistent: bool = True,
|
|
||||||
) -> None:
|
|
||||||
try:
|
|
||||||
await self._publish_once(routing_key, message, exchange, persistent)
|
|
||||||
except aio_pika.exceptions.ChannelInvalidStateError:
|
|
||||||
logger.warning(
|
|
||||||
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
|
|
||||||
)
|
|
||||||
async with self._lock:
|
|
||||||
self._channel = None
|
|
||||||
await self._publish_once(routing_key, message, exchange, persistent)
|
|
||||||
|
|
||||||
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||||
return await self._ensure_channel()
|
if not self.is_ready:
|
||||||
|
await self.connect()
|
||||||
|
if self._channel is None:
|
||||||
|
raise RuntimeError("Channel should be established after connect")
|
||||||
|
return self._channel
|
||||||
|
|||||||
@@ -213,9 +213,6 @@ async def execute_node(
|
|||||||
block_name=node_block.name,
|
block_name=node_block.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
if node_block.disabled:
|
|
||||||
raise ValueError(f"Block {node_block.id} is disabled and cannot be executed")
|
|
||||||
|
|
||||||
# Sanity check: validate the execution input.
|
# Sanity check: validate the execution input.
|
||||||
input_data, error = validate_exec(node, data.inputs, resolve_input=False)
|
input_data, error = validate_exec(node, data.inputs, resolve_input=False)
|
||||||
if input_data is None:
|
if input_data is None:
|
||||||
|
|||||||
@@ -373,7 +373,7 @@ def make_node_credentials_input_map(
|
|||||||
# Get aggregated credentials fields for the graph
|
# Get aggregated credentials fields for the graph
|
||||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||||
|
|
||||||
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
|
||||||
# Best-effort map: skip missing items
|
# Best-effort map: skip missing items
|
||||||
if graph_input_name not in graph_credentials_input:
|
if graph_input_name not in graph_credentials_input:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -224,14 +224,6 @@ openweathermap_credentials = APIKeyCredentials(
|
|||||||
expires_at=None,
|
expires_at=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
elevenlabs_credentials = APIKeyCredentials(
|
|
||||||
id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c",
|
|
||||||
provider="elevenlabs",
|
|
||||||
api_key=SecretStr(settings.secrets.elevenlabs_api_key),
|
|
||||||
title="Use Credits for ElevenLabs",
|
|
||||||
expires_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_CREDENTIALS = [
|
DEFAULT_CREDENTIALS = [
|
||||||
ollama_credentials,
|
ollama_credentials,
|
||||||
revid_credentials,
|
revid_credentials,
|
||||||
@@ -260,7 +252,6 @@ DEFAULT_CREDENTIALS = [
|
|||||||
v0_credentials,
|
v0_credentials,
|
||||||
webshare_proxy_credentials,
|
webshare_proxy_credentials,
|
||||||
openweathermap_credentials,
|
openweathermap_credentials,
|
||||||
elevenlabs_credentials,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
||||||
@@ -375,8 +366,6 @@ class IntegrationCredentialsStore:
|
|||||||
all_credentials.append(webshare_proxy_credentials)
|
all_credentials.append(webshare_proxy_credentials)
|
||||||
if settings.secrets.openweathermap_api_key:
|
if settings.secrets.openweathermap_api_key:
|
||||||
all_credentials.append(openweathermap_credentials)
|
all_credentials.append(openweathermap_credentials)
|
||||||
if settings.secrets.elevenlabs_api_key:
|
|
||||||
all_credentials.append(elevenlabs_credentials)
|
|
||||||
return all_credentials
|
return all_credentials
|
||||||
|
|
||||||
async def get_creds_by_id(
|
async def get_creds_by_id(
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ class ProviderName(str, Enum):
|
|||||||
DISCORD = "discord"
|
DISCORD = "discord"
|
||||||
D_ID = "d_id"
|
D_ID = "d_id"
|
||||||
E2B = "e2b"
|
E2B = "e2b"
|
||||||
ELEVENLABS = "elevenlabs"
|
|
||||||
FAL = "fal"
|
FAL = "fal"
|
||||||
GITHUB = "github"
|
GITHUB = "github"
|
||||||
GOOGLE = "google"
|
GOOGLE = "google"
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ from pathlib import Path
|
|||||||
from typing import TYPE_CHECKING, Literal
|
from typing import TYPE_CHECKING, Literal
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
@@ -19,35 +17,6 @@ from backend.util.virus_scanner import scan_content_safe
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceUri(BaseModel):
|
|
||||||
"""Parsed workspace:// URI."""
|
|
||||||
|
|
||||||
file_ref: str # File ID or path (e.g. "abc123" or "/path/to/file.txt")
|
|
||||||
mime_type: str | None = None # MIME type from fragment (e.g. "video/mp4")
|
|
||||||
is_path: bool = False # True if file_ref is a path (starts with "/")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_workspace_uri(uri: str) -> WorkspaceUri:
|
|
||||||
"""Parse a workspace:// URI into its components.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
"workspace://abc123" → WorkspaceUri(file_ref="abc123", mime_type=None, is_path=False)
|
|
||||||
"workspace://abc123#video/mp4" → WorkspaceUri(file_ref="abc123", mime_type="video/mp4", is_path=False)
|
|
||||||
"workspace:///path/to/file.txt" → WorkspaceUri(file_ref="/path/to/file.txt", mime_type=None, is_path=True)
|
|
||||||
"""
|
|
||||||
raw = uri.removeprefix("workspace://")
|
|
||||||
mime_type: str | None = None
|
|
||||||
if "#" in raw:
|
|
||||||
raw, fragment = raw.split("#", 1)
|
|
||||||
mime_type = fragment or None
|
|
||||||
return WorkspaceUri(
|
|
||||||
file_ref=raw,
|
|
||||||
mime_type=mime_type,
|
|
||||||
is_path=raw.startswith("/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Return format options for store_media_file
|
# Return format options for store_media_file
|
||||||
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||||
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
||||||
@@ -214,20 +183,22 @@ async def store_media_file(
|
|||||||
"This file type is only available in CoPilot sessions."
|
"This file type is only available in CoPilot sessions."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse workspace reference (strips #mimeType fragment from file ID)
|
# Parse workspace reference
|
||||||
ws = parse_workspace_uri(file)
|
# workspace://abc123 - by file ID
|
||||||
|
# workspace:///path/to/file.txt - by virtual path
|
||||||
|
file_ref = file[12:] # Remove "workspace://"
|
||||||
|
|
||||||
if ws.is_path:
|
if file_ref.startswith("/"):
|
||||||
# Path reference: workspace:///path/to/file.txt
|
# Path reference
|
||||||
workspace_content = await workspace_manager.read_file(ws.file_ref)
|
workspace_content = await workspace_manager.read_file(file_ref)
|
||||||
file_info = await workspace_manager.get_file_info_by_path(ws.file_ref)
|
file_info = await workspace_manager.get_file_info_by_path(file_ref)
|
||||||
filename = sanitize_filename(
|
filename = sanitize_filename(
|
||||||
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# ID reference: workspace://abc123 or workspace://abc123#video/mp4
|
# ID reference
|
||||||
workspace_content = await workspace_manager.read_file_by_id(ws.file_ref)
|
workspace_content = await workspace_manager.read_file_by_id(file_ref)
|
||||||
file_info = await workspace_manager.get_file_info(ws.file_ref)
|
file_info = await workspace_manager.get_file_info(file_ref)
|
||||||
filename = sanitize_filename(
|
filename = sanitize_filename(
|
||||||
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||||
)
|
)
|
||||||
@@ -342,14 +313,6 @@ async def store_media_file(
|
|||||||
if not target_path.is_file():
|
if not target_path.is_file():
|
||||||
raise ValueError(f"Local file does not exist: {target_path}")
|
raise ValueError(f"Local file does not exist: {target_path}")
|
||||||
|
|
||||||
# Virus scan the local file before any further processing
|
|
||||||
local_content = target_path.read_bytes()
|
|
||||||
if len(local_content) > MAX_FILE_SIZE_BYTES:
|
|
||||||
raise ValueError(
|
|
||||||
f"File too large: {len(local_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
|
||||||
)
|
|
||||||
await scan_content_safe(local_content, filename=sanitized_file)
|
|
||||||
|
|
||||||
# Return based on requested format
|
# Return based on requested format
|
||||||
if return_format == "for_local_processing":
|
if return_format == "for_local_processing":
|
||||||
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
||||||
@@ -371,21 +334,7 @@ async def store_media_file(
|
|||||||
|
|
||||||
# Don't re-save if input was already from workspace
|
# Don't re-save if input was already from workspace
|
||||||
if is_from_workspace:
|
if is_from_workspace:
|
||||||
# Return original workspace reference, ensuring MIME type fragment
|
# Return original workspace reference
|
||||||
ws = parse_workspace_uri(file)
|
|
||||||
if not ws.mime_type:
|
|
||||||
# Add MIME type fragment if missing (older refs without it)
|
|
||||||
try:
|
|
||||||
if ws.is_path:
|
|
||||||
info = await workspace_manager.get_file_info_by_path(
|
|
||||||
ws.file_ref
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
info = await workspace_manager.get_file_info(ws.file_ref)
|
|
||||||
if info:
|
|
||||||
return MediaFileType(f"{file}#{info.mimeType}")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return MediaFileType(file)
|
return MediaFileType(file)
|
||||||
|
|
||||||
# Save new content to workspace
|
# Save new content to workspace
|
||||||
@@ -397,7 +346,7 @@ async def store_media_file(
|
|||||||
filename=filename,
|
filename=filename,
|
||||||
overwrite=True,
|
overwrite=True,
|
||||||
)
|
)
|
||||||
return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}")
|
return MediaFileType(f"workspace://{file_record.id}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid return_format: {return_format}")
|
raise ValueError(f"Invalid return_format: {return_format}")
|
||||||
|
|||||||
@@ -247,100 +247,3 @@ class TestFileCloudIntegration:
|
|||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
return_format="for_local_processing",
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_media_file_local_path_scanned(self):
|
|
||||||
"""Test that local file paths are scanned for viruses."""
|
|
||||||
graph_exec_id = "test-exec-123"
|
|
||||||
local_file = "test_video.mp4"
|
|
||||||
file_content = b"fake video content"
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.util.file.get_cloud_storage_handler"
|
|
||||||
) as mock_handler_getter, patch(
|
|
||||||
"backend.util.file.scan_content_safe"
|
|
||||||
) as mock_scan, patch(
|
|
||||||
"backend.util.file.Path"
|
|
||||||
) as mock_path_class:
|
|
||||||
|
|
||||||
# Mock cloud storage handler - not a cloud path
|
|
||||||
mock_handler = MagicMock()
|
|
||||||
mock_handler.is_cloud_path.return_value = False
|
|
||||||
mock_handler_getter.return_value = mock_handler
|
|
||||||
|
|
||||||
# Mock virus scanner
|
|
||||||
mock_scan.return_value = None
|
|
||||||
|
|
||||||
# Mock file system operations
|
|
||||||
mock_base_path = MagicMock()
|
|
||||||
mock_target_path = MagicMock()
|
|
||||||
mock_resolved_path = MagicMock()
|
|
||||||
|
|
||||||
mock_path_class.return_value = mock_base_path
|
|
||||||
mock_base_path.mkdir = MagicMock()
|
|
||||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
|
||||||
mock_target_path.resolve.return_value = mock_resolved_path
|
|
||||||
mock_resolved_path.is_relative_to.return_value = True
|
|
||||||
mock_resolved_path.is_file.return_value = True
|
|
||||||
mock_resolved_path.read_bytes.return_value = file_content
|
|
||||||
mock_resolved_path.relative_to.return_value = Path(local_file)
|
|
||||||
mock_resolved_path.name = local_file
|
|
||||||
|
|
||||||
result = await store_media_file(
|
|
||||||
file=MediaFileType(local_file),
|
|
||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify virus scan was called for local file
|
|
||||||
mock_scan.assert_called_once_with(file_content, filename=local_file)
|
|
||||||
|
|
||||||
# Result should be the relative path
|
|
||||||
assert str(result) == local_file
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_media_file_local_path_virus_detected(self):
|
|
||||||
"""Test that infected local files raise VirusDetectedError."""
|
|
||||||
from backend.api.features.store.exceptions import VirusDetectedError
|
|
||||||
|
|
||||||
graph_exec_id = "test-exec-123"
|
|
||||||
local_file = "infected.exe"
|
|
||||||
file_content = b"malicious content"
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.util.file.get_cloud_storage_handler"
|
|
||||||
) as mock_handler_getter, patch(
|
|
||||||
"backend.util.file.scan_content_safe"
|
|
||||||
) as mock_scan, patch(
|
|
||||||
"backend.util.file.Path"
|
|
||||||
) as mock_path_class:
|
|
||||||
|
|
||||||
# Mock cloud storage handler - not a cloud path
|
|
||||||
mock_handler = MagicMock()
|
|
||||||
mock_handler.is_cloud_path.return_value = False
|
|
||||||
mock_handler_getter.return_value = mock_handler
|
|
||||||
|
|
||||||
# Mock virus scanner to detect virus
|
|
||||||
mock_scan.side_effect = VirusDetectedError(
|
|
||||||
"EICAR-Test-File", "File rejected due to virus detection"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock file system operations
|
|
||||||
mock_base_path = MagicMock()
|
|
||||||
mock_target_path = MagicMock()
|
|
||||||
mock_resolved_path = MagicMock()
|
|
||||||
|
|
||||||
mock_path_class.return_value = mock_base_path
|
|
||||||
mock_base_path.mkdir = MagicMock()
|
|
||||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
|
||||||
mock_target_path.resolve.return_value = mock_resolved_path
|
|
||||||
mock_resolved_path.is_relative_to.return_value = True
|
|
||||||
mock_resolved_path.is_file.return_value = True
|
|
||||||
mock_resolved_path.read_bytes.return_value = file_content
|
|
||||||
|
|
||||||
with pytest.raises(VirusDetectedError):
|
|
||||||
await store_media_file(
|
|
||||||
file=MediaFileType(local_file),
|
|
||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -364,44 +364,6 @@ def _remove_orphan_tool_responses(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def validate_and_remove_orphan_tool_responses(
|
|
||||||
messages: list[dict],
|
|
||||||
log_warning: bool = True,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Validate tool_call/tool_response pairs and remove orphaned responses.
|
|
||||||
|
|
||||||
Scans messages in order, tracking all tool_call IDs. Any tool response
|
|
||||||
referencing an ID not seen in a preceding message is considered orphaned
|
|
||||||
and removed. This prevents API errors like Anthropic's "unexpected tool_use_id".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: List of messages to validate (OpenAI or Anthropic format)
|
|
||||||
log_warning: Whether to log a warning when orphans are found
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A new list with orphaned tool responses removed
|
|
||||||
"""
|
|
||||||
available_ids: set[str] = set()
|
|
||||||
orphan_ids: set[str] = set()
|
|
||||||
|
|
||||||
for msg in messages:
|
|
||||||
available_ids |= _extract_tool_call_ids_from_message(msg)
|
|
||||||
for resp_id in _extract_tool_response_ids_from_message(msg):
|
|
||||||
if resp_id not in available_ids:
|
|
||||||
orphan_ids.add(resp_id)
|
|
||||||
|
|
||||||
if not orphan_ids:
|
|
||||||
return messages
|
|
||||||
|
|
||||||
if log_warning:
|
|
||||||
logger.warning(
|
|
||||||
f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return _remove_orphan_tool_responses(messages, orphan_ids)
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_tool_pairs_intact(
|
def _ensure_tool_pairs_intact(
|
||||||
recent_messages: list[dict],
|
recent_messages: list[dict],
|
||||||
all_messages: list[dict],
|
all_messages: list[dict],
|
||||||
@@ -761,13 +723,6 @@ async def compress_context(
|
|||||||
|
|
||||||
# Filter out any None values that may have been introduced
|
# Filter out any None values that may have been introduced
|
||||||
final_msgs: list[dict] = [m for m in msgs if m is not None]
|
final_msgs: list[dict] = [m for m in msgs if m is not None]
|
||||||
|
|
||||||
# ---- STEP 6: Final tool-pair validation ---------------------------------
|
|
||||||
# After all compression steps, verify that every tool response has a
|
|
||||||
# matching tool_call in a preceding assistant message. Remove orphans
|
|
||||||
# to prevent API errors (e.g., Anthropic's "unexpected tool_use_id").
|
|
||||||
final_msgs = validate_and_remove_orphan_tool_responses(final_msgs)
|
|
||||||
|
|
||||||
final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
|
final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
|
||||||
error = None
|
error = None
|
||||||
if final_count + reserve > target_tokens:
|
if final_count + reserve > target_tokens:
|
||||||
|
|||||||
@@ -157,7 +157,12 @@ async def validate_url(
|
|||||||
is_trusted: Boolean indicating if the hostname is in trusted_origins
|
is_trusted: Boolean indicating if the hostname is in trusted_origins
|
||||||
ip_addresses: List of IP addresses for the host; empty if the host is trusted
|
ip_addresses: List of IP addresses for the host; empty if the host is trusted
|
||||||
"""
|
"""
|
||||||
parsed = parse_url(url)
|
# Canonicalize URL
|
||||||
|
url = url.strip("/ ").replace("\\", "/")
|
||||||
|
parsed = urlparse(url)
|
||||||
|
if not parsed.scheme:
|
||||||
|
url = f"http://{url}"
|
||||||
|
parsed = urlparse(url)
|
||||||
|
|
||||||
# Check scheme
|
# Check scheme
|
||||||
if parsed.scheme not in ALLOWED_SCHEMES:
|
if parsed.scheme not in ALLOWED_SCHEMES:
|
||||||
@@ -215,17 +220,6 @@ async def validate_url(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_url(url: str) -> URL:
|
|
||||||
"""Canonicalizes and parses a URL string."""
|
|
||||||
url = url.strip("/ ").replace("\\", "/")
|
|
||||||
|
|
||||||
# Ensure scheme is present for proper parsing
|
|
||||||
if not re.match(r"[a-z0-9+.\-]+://", url):
|
|
||||||
url = f"http://{url}"
|
|
||||||
|
|
||||||
return urlparse(url)
|
|
||||||
|
|
||||||
|
|
||||||
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
|
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
|
||||||
"""
|
"""
|
||||||
Pins a URL to a specific IP address to prevent DNS rebinding attacks.
|
Pins a URL to a specific IP address to prevent DNS rebinding attacks.
|
||||||
|
|||||||
@@ -656,7 +656,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|||||||
e2b_api_key: str = Field(default="", description="E2B API key")
|
e2b_api_key: str = Field(default="", description="E2B API key")
|
||||||
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
||||||
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
||||||
elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
|
|
||||||
|
|
||||||
linear_client_id: str = Field(default="", description="Linear client ID")
|
linear_client_id: str = Field(default="", description="Linear client ID")
|
||||||
linear_client_secret: str = Field(default="", description="Linear client secret")
|
linear_client_secret: str = Field(default="", description="Linear client secret")
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from backend.data.workspace import (
|
|||||||
soft_delete_workspace_file,
|
soft_delete_workspace_file,
|
||||||
)
|
)
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
from backend.util.virus_scanner import scan_content_safe
|
|
||||||
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -188,9 +187,6 @@ class WorkspaceManager:
|
|||||||
f"{Config().max_file_size_mb}MB limit"
|
f"{Config().max_file_size_mb}MB limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Virus scan content before persisting (defense in depth)
|
|
||||||
await scan_content_safe(content, filename=filename)
|
|
||||||
|
|
||||||
# Determine path with session scoping
|
# Determine path with session scoping
|
||||||
if path is None:
|
if path is None:
|
||||||
path = f"/{filename}"
|
path = f"/{filename}"
|
||||||
|
|||||||
7062
autogpt_platform/backend/poetry.lock
generated
7062
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -12,16 +12,15 @@ python = ">=3.10,<3.14"
|
|||||||
aio-pika = "^9.5.5"
|
aio-pika = "^9.5.5"
|
||||||
aiohttp = "^3.10.0"
|
aiohttp = "^3.10.0"
|
||||||
aiodns = "^3.5.0"
|
aiodns = "^3.5.0"
|
||||||
anthropic = "^0.79.0"
|
anthropic = "^0.59.0"
|
||||||
apscheduler = "^3.11.1"
|
apscheduler = "^3.11.1"
|
||||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||||
click = "^8.2.0"
|
click = "^8.2.0"
|
||||||
cryptography = "^46.0"
|
cryptography = "^45.0"
|
||||||
discord-py = "^2.5.2"
|
discord-py = "^2.5.2"
|
||||||
e2b-code-interpreter = "^1.5.2"
|
e2b-code-interpreter = "^1.5.2"
|
||||||
elevenlabs = "^1.50.0"
|
fastapi = "^0.116.1"
|
||||||
fastapi = "^0.128.6"
|
|
||||||
feedparser = "^6.0.11"
|
feedparser = "^6.0.11"
|
||||||
flake8 = "^7.3.0"
|
flake8 = "^7.3.0"
|
||||||
google-api-python-client = "^2.177.0"
|
google-api-python-client = "^2.177.0"
|
||||||
@@ -34,11 +33,11 @@ html2text = "^2024.2.26"
|
|||||||
jinja2 = "^3.1.6"
|
jinja2 = "^3.1.6"
|
||||||
jsonref = "^1.1.0"
|
jsonref = "^1.1.0"
|
||||||
jsonschema = "^4.25.0"
|
jsonschema = "^4.25.0"
|
||||||
langfuse = "^3.14.1"
|
langfuse = "^3.11.0"
|
||||||
launchdarkly-server-sdk = "^9.14.1"
|
launchdarkly-server-sdk = "^9.12.0"
|
||||||
mem0ai = "^0.1.115"
|
mem0ai = "^0.1.115"
|
||||||
moviepy = "^2.1.2"
|
moviepy = "^2.1.2"
|
||||||
ollama = "^0.6.1"
|
ollama = "^0.5.1"
|
||||||
openai = "^1.97.1"
|
openai = "^1.97.1"
|
||||||
orjson = "^3.10.0"
|
orjson = "^3.10.0"
|
||||||
pika = "^1.3.2"
|
pika = "^1.3.2"
|
||||||
@@ -48,16 +47,16 @@ postmarker = "^1.0"
|
|||||||
praw = "~7.8.1"
|
praw = "~7.8.1"
|
||||||
prisma = "^0.15.0"
|
prisma = "^0.15.0"
|
||||||
rank-bm25 = "^0.2.2"
|
rank-bm25 = "^0.2.2"
|
||||||
prometheus-client = "^0.24.1"
|
prometheus-client = "^0.22.1"
|
||||||
prometheus-fastapi-instrumentator = "^7.0.0"
|
prometheus-fastapi-instrumentator = "^7.0.0"
|
||||||
psutil = "^7.0.0"
|
psutil = "^7.0.0"
|
||||||
psycopg2-binary = "^2.9.10"
|
psycopg2-binary = "^2.9.10"
|
||||||
pydantic = { extras = ["email"], version = "^2.12.5" }
|
pydantic = { extras = ["email"], version = "^2.11.7" }
|
||||||
pydantic-settings = "^2.12.0"
|
pydantic-settings = "^2.10.1"
|
||||||
pytest = "^8.4.1"
|
pytest = "^8.4.1"
|
||||||
pytest-asyncio = "^1.1.0"
|
pytest-asyncio = "^1.1.0"
|
||||||
python-dotenv = "^1.1.1"
|
python-dotenv = "^1.1.1"
|
||||||
python-multipart = "^0.0.22"
|
python-multipart = "^0.0.20"
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
regex = "^2025.9.18"
|
regex = "^2025.9.18"
|
||||||
replicate = "^1.0.6"
|
replicate = "^1.0.6"
|
||||||
@@ -65,19 +64,18 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
|
|||||||
sqlalchemy = "^2.0.40"
|
sqlalchemy = "^2.0.40"
|
||||||
strenum = "^0.4.9"
|
strenum = "^0.4.9"
|
||||||
stripe = "^11.5.0"
|
stripe = "^11.5.0"
|
||||||
supabase = "2.27.3"
|
supabase = "2.17.0"
|
||||||
tenacity = "^9.1.4"
|
tenacity = "^9.1.2"
|
||||||
todoist-api-python = "^2.1.7"
|
todoist-api-python = "^2.1.7"
|
||||||
tweepy = "^4.16.0"
|
tweepy = "^4.16.0"
|
||||||
uvicorn = { extras = ["standard"], version = "^0.40.0" }
|
uvicorn = { extras = ["standard"], version = "^0.35.0" }
|
||||||
websockets = "^15.0"
|
websockets = "^15.0"
|
||||||
youtube-transcript-api = "^1.2.1"
|
youtube-transcript-api = "^1.2.1"
|
||||||
yt-dlp = "2025.12.08"
|
|
||||||
zerobouncesdk = "^1.1.2"
|
zerobouncesdk = "^1.1.2"
|
||||||
# NOTE: please insert new dependencies in their alphabetical location
|
# NOTE: please insert new dependencies in their alphabetical location
|
||||||
pytest-snapshot = "^0.9.0"
|
pytest-snapshot = "^0.9.0"
|
||||||
aiofiles = "^25.1.0"
|
aiofiles = "^24.1.0"
|
||||||
tiktoken = "^0.12.0"
|
tiktoken = "^0.9.0"
|
||||||
aioclamd = "^1.0.0"
|
aioclamd = "^1.0.0"
|
||||||
setuptools = "^80.9.0"
|
setuptools = "^80.9.0"
|
||||||
gcloud-aio-storage = "^9.5.0"
|
gcloud-aio-storage = "^9.5.0"
|
||||||
@@ -95,13 +93,13 @@ black = "^24.10.0"
|
|||||||
faker = "^38.2.0"
|
faker = "^38.2.0"
|
||||||
httpx = "^0.28.1"
|
httpx = "^0.28.1"
|
||||||
isort = "^5.13.2"
|
isort = "^5.13.2"
|
||||||
poethepoet = "^0.41.0"
|
poethepoet = "^0.37.0"
|
||||||
pre-commit = "^4.4.0"
|
pre-commit = "^4.4.0"
|
||||||
pyright = "^1.1.407"
|
pyright = "^1.1.407"
|
||||||
pytest-mock = "^3.15.1"
|
pytest-mock = "^3.15.1"
|
||||||
pytest-watcher = "^0.6.3"
|
pytest-watcher = "^0.4.2"
|
||||||
requests = "^2.32.5"
|
requests = "^2.32.5"
|
||||||
ruff = "^0.15.0"
|
ruff = "^0.14.5"
|
||||||
# NOTE: please insert new dependencies in their alphabetical location
|
# NOTE: please insert new dependencies in their alphabetical location
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"credentials_input_schema": {
|
"credentials_input_schema": {
|
||||||
"properties": {},
|
"properties": {},
|
||||||
"required": [],
|
"required": [],
|
||||||
|
"title": "TestGraphCredentialsInputSchema",
|
||||||
"type": "object"
|
"type": "object"
|
||||||
},
|
},
|
||||||
"description": "A test graph",
|
"description": "A test graph",
|
||||||
|
|||||||
@@ -1,14 +1,34 @@
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"created_at": "2025-09-04T13:37:00",
|
"credentials_input_schema": {
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
"title": "TestGraphCredentialsInputSchema",
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"description": "A test graph",
|
"description": "A test graph",
|
||||||
"forked_from_id": null,
|
"forked_from_id": null,
|
||||||
"forked_from_version": null,
|
"forked_from_version": null,
|
||||||
|
"has_external_trigger": false,
|
||||||
|
"has_human_in_the_loop": false,
|
||||||
|
"has_sensitive_action": false,
|
||||||
"id": "graph-123",
|
"id": "graph-123",
|
||||||
|
"input_schema": {
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"instructions": null,
|
"instructions": null,
|
||||||
"is_active": true,
|
"is_active": true,
|
||||||
"name": "Test Graph",
|
"name": "Test Graph",
|
||||||
|
"output_schema": {
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"recommended_schedule_cron": null,
|
"recommended_schedule_cron": null,
|
||||||
|
"sub_graphs": [],
|
||||||
|
"trigger_setup_info": null,
|
||||||
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||||
"version": 1
|
"version": 1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,12 +25,8 @@ RUN if [ -f .env.production ]; then \
|
|||||||
cp .env.default .env; \
|
cp .env.default .env; \
|
||||||
fi
|
fi
|
||||||
RUN pnpm run generate:api
|
RUN pnpm run generate:api
|
||||||
# Disable source-map generation in Docker builds to halve webpack memory usage.
|
|
||||||
# Source maps are only useful when SENTRY_AUTH_TOKEN is set (Vercel deploys);
|
|
||||||
# the Docker image never uploads them, so generating them just wastes RAM.
|
|
||||||
ENV NEXT_PUBLIC_SOURCEMAPS="false"
|
|
||||||
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
|
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
|
||||||
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=8192" pnpm build; else NODE_OPTIONS="--max-old-space-size=8192" pnpm build; fi
|
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=4096" pnpm build; else NODE_OPTIONS="--max-old-space-size=4096" pnpm build; fi
|
||||||
|
|
||||||
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
|
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
|
||||||
FROM node:21-alpine AS prod
|
FROM node:21-alpine AS prod
|
||||||
|
|||||||
@@ -1,12 +1,8 @@
|
|||||||
import { withSentryConfig } from "@sentry/nextjs";
|
import { withSentryConfig } from "@sentry/nextjs";
|
||||||
|
|
||||||
// Allow Docker builds to skip source-map generation (halves memory usage).
|
|
||||||
// Defaults to true so Vercel/local builds are unaffected.
|
|
||||||
const enableSourceMaps = process.env.NEXT_PUBLIC_SOURCEMAPS !== "false";
|
|
||||||
|
|
||||||
/** @type {import('next').NextConfig} */
|
/** @type {import('next').NextConfig} */
|
||||||
const nextConfig = {
|
const nextConfig = {
|
||||||
productionBrowserSourceMaps: enableSourceMaps,
|
productionBrowserSourceMaps: true,
|
||||||
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
|
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
|
||||||
serverExternalPackages: [
|
serverExternalPackages: [
|
||||||
"@opentelemetry/instrumentation",
|
"@opentelemetry/instrumentation",
|
||||||
@@ -18,37 +14,9 @@ const nextConfig = {
|
|||||||
serverActions: {
|
serverActions: {
|
||||||
bodySizeLimit: "256mb",
|
bodySizeLimit: "256mb",
|
||||||
},
|
},
|
||||||
|
// Increase body size limit for API routes (file uploads) - 256MB to match backend limit
|
||||||
|
proxyClientMaxBodySize: "256mb",
|
||||||
middlewareClientMaxBodySize: "256mb",
|
middlewareClientMaxBodySize: "256mb",
|
||||||
// Limit parallel webpack workers to reduce peak memory during builds.
|
|
||||||
cpus: 2,
|
|
||||||
},
|
|
||||||
// Work around cssnano "Invalid array length" bug in Next.js's bundled
|
|
||||||
// cssnano-simple comment parser when processing very large CSS chunks.
|
|
||||||
// CSS is still bundled correctly; gzip handles most of the size savings anyway.
|
|
||||||
webpack: (config, { dev }) => {
|
|
||||||
if (!dev) {
|
|
||||||
// Next.js adds CssMinimizerPlugin internally (after user config), so we
|
|
||||||
// can't filter it from config.plugins. Instead, intercept the webpack
|
|
||||||
// compilation hooks and replace the buggy plugin's tap with a no-op.
|
|
||||||
config.plugins.push({
|
|
||||||
apply(compiler) {
|
|
||||||
compiler.hooks.compilation.tap(
|
|
||||||
"DisableCssMinimizer",
|
|
||||||
(compilation) => {
|
|
||||||
compilation.hooks.processAssets.intercept({
|
|
||||||
register: (tap) => {
|
|
||||||
if (tap.name === "CssMinimizerPlugin") {
|
|
||||||
return { ...tap, fn: async () => {} };
|
|
||||||
}
|
|
||||||
return tap;
|
|
||||||
},
|
|
||||||
});
|
|
||||||
},
|
|
||||||
);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return config;
|
|
||||||
},
|
},
|
||||||
images: {
|
images: {
|
||||||
domains: [
|
domains: [
|
||||||
@@ -86,16 +54,9 @@ const nextConfig = {
|
|||||||
transpilePackages: ["geist"],
|
transpilePackages: ["geist"],
|
||||||
};
|
};
|
||||||
|
|
||||||
// Only run the Sentry webpack plugin when we can actually upload source maps
|
const isDevelopmentBuild = process.env.NODE_ENV !== "production";
|
||||||
// (i.e. on Vercel with SENTRY_AUTH_TOKEN set). The Sentry *runtime* SDK
|
|
||||||
// (imported in app code) still captures errors without the plugin.
|
|
||||||
// Skipping the plugin saves ~1 GB of peak memory during `next build`.
|
|
||||||
const skipSentryPlugin =
|
|
||||||
process.env.NODE_ENV !== "production" ||
|
|
||||||
!enableSourceMaps ||
|
|
||||||
!process.env.SENTRY_AUTH_TOKEN;
|
|
||||||
|
|
||||||
export default skipSentryPlugin
|
export default isDevelopmentBuild
|
||||||
? nextConfig
|
? nextConfig
|
||||||
: withSentryConfig(nextConfig, {
|
: withSentryConfig(nextConfig, {
|
||||||
// For all available options, see:
|
// For all available options, see:
|
||||||
@@ -135,7 +96,7 @@ export default skipSentryPlugin
|
|||||||
|
|
||||||
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
|
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
|
||||||
sourcemaps: {
|
sourcemaps: {
|
||||||
disable: !enableSourceMaps,
|
disable: false,
|
||||||
assets: [".next/**/*.js", ".next/**/*.js.map"],
|
assets: [".next/**/*.js", ".next/**/*.js.map"],
|
||||||
ignore: ["**/node_modules/**"],
|
ignore: ["**/node_modules/**"],
|
||||||
deleteSourcemapsAfterUpload: false, // Source is public anyway :)
|
deleteSourcemapsAfterUpload: false, // Source is public anyway :)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
},
|
},
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "pnpm run generate:api:force && next dev --turbo",
|
"dev": "pnpm run generate:api:force && next dev --turbo",
|
||||||
"build": "cross-env NODE_OPTIONS=--max-old-space-size=16384 next build",
|
"build": "next build",
|
||||||
"start": "next start",
|
"start": "next start",
|
||||||
"start:standalone": "cd .next/standalone && node server.js",
|
"start:standalone": "cd .next/standalone && node server.js",
|
||||||
"lint": "next lint && prettier --check .",
|
"lint": "next lint && prettier --check .",
|
||||||
@@ -30,7 +30,6 @@
|
|||||||
"defaults"
|
"defaults"
|
||||||
],
|
],
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@ai-sdk/react": "3.0.61",
|
|
||||||
"@faker-js/faker": "10.0.0",
|
"@faker-js/faker": "10.0.0",
|
||||||
"@hookform/resolvers": "5.2.2",
|
"@hookform/resolvers": "5.2.2",
|
||||||
"@next/third-parties": "15.4.6",
|
"@next/third-parties": "15.4.6",
|
||||||
@@ -61,10 +60,6 @@
|
|||||||
"@rjsf/utils": "6.1.2",
|
"@rjsf/utils": "6.1.2",
|
||||||
"@rjsf/validator-ajv8": "6.1.2",
|
"@rjsf/validator-ajv8": "6.1.2",
|
||||||
"@sentry/nextjs": "10.27.0",
|
"@sentry/nextjs": "10.27.0",
|
||||||
"@streamdown/cjk": "1.0.1",
|
|
||||||
"@streamdown/code": "1.0.1",
|
|
||||||
"@streamdown/math": "1.0.1",
|
|
||||||
"@streamdown/mermaid": "1.0.1",
|
|
||||||
"@supabase/ssr": "0.7.0",
|
"@supabase/ssr": "0.7.0",
|
||||||
"@supabase/supabase-js": "2.78.0",
|
"@supabase/supabase-js": "2.78.0",
|
||||||
"@tanstack/react-query": "5.90.6",
|
"@tanstack/react-query": "5.90.6",
|
||||||
@@ -73,7 +68,6 @@
|
|||||||
"@vercel/analytics": "1.5.0",
|
"@vercel/analytics": "1.5.0",
|
||||||
"@vercel/speed-insights": "1.2.0",
|
"@vercel/speed-insights": "1.2.0",
|
||||||
"@xyflow/react": "12.9.2",
|
"@xyflow/react": "12.9.2",
|
||||||
"ai": "6.0.59",
|
|
||||||
"boring-avatars": "1.11.2",
|
"boring-avatars": "1.11.2",
|
||||||
"class-variance-authority": "0.7.1",
|
"class-variance-authority": "0.7.1",
|
||||||
"clsx": "2.1.1",
|
"clsx": "2.1.1",
|
||||||
@@ -93,6 +87,7 @@
|
|||||||
"launchdarkly-react-client-sdk": "3.9.0",
|
"launchdarkly-react-client-sdk": "3.9.0",
|
||||||
"lodash": "4.17.21",
|
"lodash": "4.17.21",
|
||||||
"lucide-react": "0.552.0",
|
"lucide-react": "0.552.0",
|
||||||
|
"moment": "2.30.1",
|
||||||
"next": "15.4.10",
|
"next": "15.4.10",
|
||||||
"next-themes": "0.4.6",
|
"next-themes": "0.4.6",
|
||||||
"nuqs": "2.7.2",
|
"nuqs": "2.7.2",
|
||||||
@@ -107,7 +102,7 @@
|
|||||||
"react-markdown": "9.0.3",
|
"react-markdown": "9.0.3",
|
||||||
"react-modal": "3.16.3",
|
"react-modal": "3.16.3",
|
||||||
"react-shepherd": "6.1.9",
|
"react-shepherd": "6.1.9",
|
||||||
"react-window": "2.2.0",
|
"react-window": "1.8.11",
|
||||||
"recharts": "3.3.0",
|
"recharts": "3.3.0",
|
||||||
"rehype-autolink-headings": "7.1.0",
|
"rehype-autolink-headings": "7.1.0",
|
||||||
"rehype-highlight": "7.0.2",
|
"rehype-highlight": "7.0.2",
|
||||||
@@ -117,11 +112,9 @@
|
|||||||
"remark-math": "6.0.0",
|
"remark-math": "6.0.0",
|
||||||
"shepherd.js": "14.5.1",
|
"shepherd.js": "14.5.1",
|
||||||
"sonner": "2.0.7",
|
"sonner": "2.0.7",
|
||||||
"streamdown": "2.1.0",
|
|
||||||
"tailwind-merge": "2.6.0",
|
"tailwind-merge": "2.6.0",
|
||||||
"tailwind-scrollbar": "3.1.0",
|
"tailwind-scrollbar": "3.1.0",
|
||||||
"tailwindcss-animate": "1.0.7",
|
"tailwindcss-animate": "1.0.7",
|
||||||
"use-stick-to-bottom": "1.1.2",
|
|
||||||
"uuid": "11.1.0",
|
"uuid": "11.1.0",
|
||||||
"vaul": "1.1.2",
|
"vaul": "1.1.2",
|
||||||
"zod": "3.25.76",
|
"zod": "3.25.76",
|
||||||
@@ -147,7 +140,7 @@
|
|||||||
"@types/react": "18.3.17",
|
"@types/react": "18.3.17",
|
||||||
"@types/react-dom": "18.3.5",
|
"@types/react-dom": "18.3.5",
|
||||||
"@types/react-modal": "3.16.3",
|
"@types/react-modal": "3.16.3",
|
||||||
"@types/react-window": "2.0.0",
|
"@types/react-window": "1.8.8",
|
||||||
"@vitejs/plugin-react": "5.1.2",
|
"@vitejs/plugin-react": "5.1.2",
|
||||||
"axe-playwright": "2.2.2",
|
"axe-playwright": "2.2.2",
|
||||||
"chromatic": "13.3.3",
|
"chromatic": "13.3.3",
|
||||||
@@ -179,8 +172,7 @@
|
|||||||
},
|
},
|
||||||
"pnpm": {
|
"pnpm": {
|
||||||
"overrides": {
|
"overrides": {
|
||||||
"@opentelemetry/instrumentation": "0.209.0",
|
"@opentelemetry/instrumentation": "0.209.0"
|
||||||
"lodash-es": "4.17.23"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
|
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
|
||||||
|
|||||||
1218
autogpt_platform/frontend/pnpm-lock.yaml
generated
1218
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user