Compare commits


1 Commit

Author SHA1 Message Date
Otto
7f7a7067ec refactor(copilot): use Pydantic models and match/case in customize_agent
Addresses review feedback from ntindle:

1. Use typed parameters instead of kwargs.get():
   - Added CustomizeAgentInput Pydantic model with field_validator for stripping strings
   - Tool now uses params = CustomizeAgentInput(**kwargs) pattern

2. Use match/case for cleaner pattern matching:
   - Extracted response handling to _handle_customization_result method
   - Uses match result_type: case 'error' | 'clarifying_questions' | _

3. Improved code organization:
   - Split monolithic _execute into smaller focused methods
   - _handle_customization_result for response type handling
   - _save_or_preview_agent for final save/preview logic
2026-02-04 08:53:02 +00:00
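
A minimal sketch of the pattern the commit message describes, assuming Pydantic v2. CustomizeAgentInput, _handle_customization_result, and _save_or_preview_agent are named in the message; the specific fields and the _customize helper are hypothetical stand-ins for illustration:

from typing import Any

from pydantic import BaseModel, field_validator


class CustomizeAgentInput(BaseModel):
    # Illustrative fields; the real model defines whatever parameters the tool accepts.
    agent_id: str
    instructions: str

    @field_validator("*", mode="before")
    @classmethod
    def strip_strings(cls, value: Any) -> Any:
        # Strip surrounding whitespace from string inputs, per the commit message.
        return value.strip() if isinstance(value, str) else value


class CustomizeAgentTool:
    async def _execute(self, **kwargs: Any) -> dict[str, Any]:
        # Typed parameters instead of kwargs.get()
        params = CustomizeAgentInput(**kwargs)
        # _customize is a hypothetical helper standing in for the actual customization call.
        result_type, payload = await self._customize(params)
        return self._handle_customization_result(result_type, payload)

    def _handle_customization_result(self, result_type: str, payload: Any) -> dict[str, Any]:
        # match/case on the response type instead of an if/elif chain
        match result_type:
            case "error":
                return {"type": "error", "message": str(payload)}
            case "clarifying_questions":
                return {"type": "clarifying_questions", "questions": payload}
            case _:
                return self._save_or_preview_agent(payload)

    def _save_or_preview_agent(self, payload: Any) -> dict[str, Any]:
        # Final save/preview logic would live here.
        return {"type": "preview", "agent": payload}

    async def _customize(self, params: CustomizeAgentInput) -> tuple[str, Any]:
        # Hypothetical: returns a (result_type, payload) pair.
        return "preview", {"agent_id": params.agent_id}
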
285 changed files with 13422 additions and 21975 deletions

View File

@@ -49,7 +49,7 @@ jobs:
       - name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
         if: github.event_name == 'push'
-        uses: peter-evans/create-pull-request@v8
+        uses: peter-evans/create-pull-request@v7
         with:
           add-paths: classic/frontend/build/web
           base: ${{ github.ref_name }}

View File

@@ -42,7 +42,7 @@ jobs:
       - name: Get CI failure details
        id: failure_details
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             const run = await github.rest.actions.getWorkflowRun({

View File

@@ -41,7 +41,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -78,7 +78,7 @@ jobs:
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22"
@@ -91,7 +91,7 @@ jobs:
           echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
       - name: Cache frontend dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -124,7 +124,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes
@@ -309,7 +309,6 @@ jobs:
         uses: anthropics/claude-code-action@v1
         with:
           claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
-          allowed_bots: "dependabot[bot]"
           claude_args: |
             --allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
           prompt: |

View File

@@ -57,7 +57,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -94,7 +94,7 @@ jobs:
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22"
@@ -107,7 +107,7 @@ jobs:
           echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
       - name: Cache frontend dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -140,7 +140,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes

View File

@@ -39,7 +39,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -76,7 +76,7 @@ jobs:
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22"
@@ -89,7 +89,7 @@ jobs:
           echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
       - name: Cache frontend dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -132,7 +132,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes

View File

@@ -33,7 +33,7 @@ jobs:
           python-version: "3.11"
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -33,7 +33,7 @@ jobs:
           python-version: "3.11"
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -38,7 +38,7 @@ jobs:
           python-version: "3.11"
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -88,7 +88,7 @@ jobs:
         run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -17,7 +17,7 @@ jobs:
       - name: Check comment permissions and deployment status
         id: check_status
         if: github.event_name == 'issue_comment' && github.event.issue.pull_request
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             const commentBody = context.payload.comment.body.trim();
@@ -55,7 +55,7 @@ jobs:
       - name: Post permission denied comment
         if: steps.check_status.outputs.permission_denied == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             await github.rest.issues.createComment({
@@ -68,7 +68,7 @@ jobs:
       - name: Get PR details for deployment
         id: pr_details
         if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             const pr = await github.rest.pulls.get({
@@ -98,7 +98,7 @@ jobs:
       - name: Post deploy success comment
         if: steps.check_status.outputs.should_deploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             await github.rest.issues.createComment({
@@ -126,7 +126,7 @@ jobs:
       - name: Post undeploy success comment
         if: steps.check_status.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             await github.rest.issues.createComment({
@@ -139,7 +139,7 @@ jobs:
       - name: Check deployment status on PR close
         id: check_pr_close
         if: github.event_name == 'pull_request' && github.event.action == 'closed'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             const comments = await github.rest.issues.listComments({
@@ -187,7 +187,7 @@ jobs:
           github.event_name == 'pull_request' &&
           github.event.action == 'closed' &&
           steps.check_pr_close.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             await github.rest.issues.createComment({

View File

@@ -27,22 +27,13 @@ jobs:
     runs-on: ubuntu-latest
     outputs:
       cache-key: ${{ steps.cache-key.outputs.key }}
-      components-changed: ${{ steps.filter.outputs.components }}
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
-      - name: Check for component changes
-        uses: dorny/paths-filter@v3
-        id: filter
-        with:
-          filters: |
-            components:
-              - 'autogpt_platform/frontend/src/components/**'
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
@@ -54,7 +45,7 @@ jobs:
         run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
       - name: Cache dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ steps.cache-key.outputs.key }}
@@ -74,7 +65,7 @@ jobs:
         uses: actions/checkout@v4
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
@@ -82,7 +73,7 @@ jobs:
         run: corepack enable
       - name: Restore dependencies cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}
@@ -99,11 +90,8 @@ jobs:
   chromatic:
     runs-on: ubuntu-latest
     needs: setup
-    # Disabled: to re-enable, remove 'false &&' from the condition below
-    if: >-
-      false
-      && (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
-      && needs.setup.outputs.components-changed == 'true'
+    # Only run on dev branch pushes or PRs targeting dev
+    if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
     steps:
       - name: Checkout repository
@@ -112,7 +100,7 @@ jobs:
           fetch-depth: 0
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
@@ -120,7 +108,7 @@ jobs:
         run: corepack enable
       - name: Restore dependencies cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}
@@ -153,7 +141,7 @@ jobs:
           submodules: recursive
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
@@ -176,7 +164,7 @@ jobs:
         uses: docker/setup-buildx-action@v3
       - name: Cache Docker layers
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: /tmp/.buildx-cache
           key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
@@ -231,7 +219,7 @@ jobs:
           fi
       - name: Restore dependencies cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}
@@ -282,7 +270,7 @@ jobs:
           submodules: recursive
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
@@ -290,7 +278,7 @@ jobs:
         run: corepack enable
       - name: Restore dependencies cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}

View File

@@ -32,7 +32,7 @@ jobs:
         uses: actions/checkout@v4
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
@@ -44,7 +44,7 @@ jobs:
         run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
       - name: Cache dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ steps.cache-key.outputs.key }}
@@ -56,7 +56,7 @@ jobs:
         run: pnpm install --frozen-lockfile
   types:
-    runs-on: big-boi
+    runs-on: ubuntu-latest
     needs: setup
     strategy:
       fail-fast: false
@@ -68,7 +68,7 @@ jobs:
           submodules: recursive
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
@@ -85,10 +85,10 @@ jobs:
       - name: Run docker compose
         run: |
-          docker compose -f ../docker-compose.yml --profile local up -d deps_backend
+          docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
       - name: Restore dependencies cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}

File diff suppressed because it is too large

View File

@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
 [tool.poetry.dependencies]
 python = ">=3.10,<4.0"
 colorama = "^0.4.6"
-cryptography = "^46.0"
+cryptography = "^45.0"
 expiringdict = "^1.2.2"
-fastapi = "^0.128.0"
-google-cloud-logging = "^3.13.0"
-launchdarkly-server-sdk = "^9.14.1"
-pydantic = "^2.12.5"
-pydantic-settings = "^2.12.0"
-pyjwt = { version = "^2.11.0", extras = ["crypto"] }
+fastapi = "^0.116.1"
+google-cloud-logging = "^3.12.1"
+launchdarkly-server-sdk = "^9.12.0"
+pydantic = "^2.11.7"
+pydantic-settings = "^2.10.1"
+pyjwt = { version = "^2.10.1", extras = ["crypto"] }
 redis = "^6.2.0"
-supabase = "^2.27.2"
-uvicorn = "^0.40.0"
+supabase = "^2.16.0"
+uvicorn = "^0.35.0"
 [tool.poetry.group.dev.dependencies]
-pyright = "^1.1.408"
+pyright = "^1.1.404"
 pytest = "^8.4.1"
-pytest-asyncio = "^1.3.0"
-pytest-mock = "^3.15.1"
-pytest-cov = "^7.0.0"
-ruff = "^0.15.0"
+pytest-asyncio = "^1.1.0"
+pytest-mock = "^3.14.1"
+pytest-cov = "^6.2.1"
+ruff = "^0.12.11"
 [build-system]
 requires = ["poetry-core"]

View File

@@ -152,7 +152,6 @@ REPLICATE_API_KEY=
 REVID_API_KEY=
 SCREENSHOTONE_API_KEY=
 UNREAL_SPEECH_API_KEY=
-ELEVENLABS_API_KEY=
 # Data & Search Services
 E2B_API_KEY=

View File

@@ -19,6 +19,3 @@ load-tests/*.json
 load-tests/*.log
 load-tests/node_modules/*
 migrations/*/rollback*.sql
-# Workspace files
-workspaces/

View File

@@ -62,12 +62,10 @@ ENV POETRY_HOME=/opt/poetry \
     DEBIAN_FRONTEND=noninteractive
 ENV PATH=/opt/poetry/bin:$PATH
-# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
+# Install Python without upgrading system-managed packages
 RUN apt-get update && apt-get install -y \
     python3.13 \
     python3-pip \
-    ffmpeg \
-    imagemagick \
     && rm -rf /var/lib/apt/lists/*
 # Copy only necessary files from builder

View File

@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):
     # OpenAI API Configuration
     model: str = Field(
-        default="anthropic/claude-opus-4.6", description="Default model to use"
+        default="anthropic/claude-opus-4.5", description="Default model to use"
     )
     title_model: str = Field(
         default="openai/gpt-4o-mini",

View File

@@ -45,7 +45,10 @@ async def create_chat_session(
         successfulAgentRuns=SafeJson({}),
         successfulAgentSchedules=SafeJson({}),
     )
-    return await PrismaChatSession.prisma().create(data=data)
+    return await PrismaChatSession.prisma().create(
+        data=data,
+        include={"Messages": True},
+    )
 async def update_chat_session(

View File

@@ -18,10 +18,6 @@ class ResponseType(str, Enum):
     START = "start"
     FINISH = "finish"
-    # Step lifecycle (one LLM API call within a message)
-    START_STEP = "start-step"
-    FINISH_STEP = "finish-step"
     # Text streaming
     TEXT_START = "text-start"
     TEXT_DELTA = "text-delta"
@@ -61,16 +57,6 @@ class StreamStart(StreamBaseResponse):
         description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
     )
-    def to_sse(self) -> str:
-        """Convert to SSE format, excluding non-protocol fields like taskId."""
-        import json
-        data: dict[str, Any] = {
-            "type": self.type.value,
-            "messageId": self.messageId,
-        }
-        return f"data: {json.dumps(data)}\n\n"
 class StreamFinish(StreamBaseResponse):
     """End of message/stream."""
@@ -78,26 +64,6 @@
     type: ResponseType = ResponseType.FINISH
-class StreamStartStep(StreamBaseResponse):
-    """Start of a step (one LLM API call within a message).
-    The AI SDK uses this to add a step-start boundary to message.parts,
-    enabling visual separation between multiple LLM calls in a single message.
-    """
-    type: ResponseType = ResponseType.START_STEP
-class StreamFinishStep(StreamBaseResponse):
-    """End of a step (one LLM API call within a message).
-    The AI SDK uses this to reset activeTextParts and activeReasoningParts,
-    so the next LLM call in a tool-call continuation starts with clean state.
-    """
-    type: ResponseType = ResponseType.FINISH_STEP
 # ========== Text Streaming ==========
@@ -151,7 +117,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
     type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
     toolCallId: str = Field(..., description="Tool call ID this responds to")
     output: str | dict[str, Any] = Field(..., description="Tool execution output")
-    # Keep these for internal backend use
+    # Additional fields for internal use (not part of AI SDK spec but useful)
     toolName: str | None = Field(
         default=None, description="Name of the tool that was executed"
     )
@@ -159,17 +125,6 @@
         default=True, description="Whether the tool execution succeeded"
     )
-    def to_sse(self) -> str:
-        """Convert to SSE format, excluding non-spec fields."""
-        import json
-        data = {
-            "type": self.type.value,
-            "toolCallId": self.toolCallId,
-            "output": self.output,
-        }
-        return f"data: {json.dumps(data)}\n\n"
 # ========== Other ==========

View File

@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator
 from typing import Annotated
 from autogpt_libs import auth
-from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
+from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
@@ -17,29 +17,7 @@ from . import stream_registry
 from .completion_handler import process_operation_failure, process_operation_success
 from .config import ChatConfig
 from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
-from .response_model import StreamFinish, StreamHeartbeat
-from .tools.models import (
-    AgentDetailsResponse,
-    AgentOutputResponse,
-    AgentPreviewResponse,
-    AgentSavedResponse,
-    AgentsFoundResponse,
-    BlockListResponse,
-    BlockOutputResponse,
-    ClarificationNeededResponse,
-    DocPageResponse,
-    DocSearchResultsResponse,
-    ErrorResponse,
-    ExecutionStartedResponse,
-    InputValidationErrorResponse,
-    NeedLoginResponse,
-    NoResultsResponse,
-    OperationInProgressResponse,
-    OperationPendingResponse,
-    OperationStartedResponse,
-    SetupRequirementsResponse,
-    UnderstandingUpdatedResponse,
-)
+from .response_model import StreamFinish, StreamHeartbeat, StreamStart
 config = ChatConfig()
@@ -288,36 +266,12 @@
     """
     import asyncio
-    import time
-    stream_start_time = time.perf_counter()
-    log_meta = {"component": "ChatStream", "session_id": session_id}
-    if user_id:
-        log_meta["user_id"] = user_id
-    logger.info(
-        f"[TIMING] stream_chat_post STARTED, session={session_id}, "
-        f"user={user_id}, message_len={len(request.message)}",
-        extra={"json_fields": log_meta},
-    )
     session = await _validate_and_get_session(session_id, user_id)
-    logger.info(
-        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
-        extra={
-            "json_fields": {
-                **log_meta,
-                "duration_ms": (time.perf_counter() - stream_start_time) * 1000,
-            }
-        },
-    )
     # Create a task in the stream registry for reconnection support
     task_id = str(uuid_module.uuid4())
     operation_id = str(uuid_module.uuid4())
-    log_meta["task_id"] = task_id
-    task_create_start = time.perf_counter()
     await stream_registry.create_task(
         task_id=task_id,
         session_id=session_id,
@@ -326,28 +280,14 @@
         tool_name="chat",
         operation_id=operation_id,
     )
-    logger.info(
-        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
-        extra={
-            "json_fields": {
-                **log_meta,
-                "duration_ms": (time.perf_counter() - task_create_start) * 1000,
-            }
-        },
-    )
     # Background task that runs the AI generation independently of SSE connection
     async def run_ai_generation():
-        import time as time_module
-        gen_start_time = time_module.perf_counter()
-        logger.info(
-            f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
-            extra={"json_fields": log_meta},
-        )
-        first_chunk_time, ttfc = None, None
-        chunk_count = 0
         try:
+            # Emit a start event with task_id for reconnection
+            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
+            await stream_registry.publish_chunk(task_id, start_chunk)
             async for chunk in chat_service.stream_chat_completion(
                 session_id,
                 request.message,
@@ -355,79 +295,25 @@
                 user_id=user_id,
                 session=session,  # Pass pre-fetched session to avoid double-fetch
                 context=request.context,
-                _task_id=task_id,  # Pass task_id so service emits start with taskId for reconnection
             ):
-                chunk_count += 1
-                if first_chunk_time is None:
-                    first_chunk_time = time_module.perf_counter()
-                    ttfc = first_chunk_time - gen_start_time
-                    logger.info(
-                        f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
-                        extra={
-                            "json_fields": {
-                                **log_meta,
-                                "chunk_type": type(chunk).__name__,
-                                "time_to_first_chunk_ms": ttfc * 1000,
-                            }
-                        },
-                    )
                 # Write to Redis (subscribers will receive via XREAD)
                 await stream_registry.publish_chunk(task_id, chunk)
-            gen_end_time = time_module.perf_counter()
-            total_time = (gen_end_time - gen_start_time) * 1000
-            logger.info(
-                f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
-                f"task={task_id}, session={session_id}, "
-                f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "total_time_ms": total_time,
-                        "time_to_first_chunk_ms": (
-                            ttfc * 1000 if ttfc is not None else None
-                        ),
-                        "n_chunks": chunk_count,
-                    }
-                },
-            )
+            # Mark task as completed
             await stream_registry.mark_task_completed(task_id, "completed")
         except Exception as e:
-            elapsed = time_module.perf_counter() - gen_start_time
             logger.error(
-                f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "elapsed_ms": elapsed * 1000,
-                        "error": str(e),
-                    }
-                },
+                f"Error in background AI generation for session {session_id}: {e}"
             )
             await stream_registry.mark_task_completed(task_id, "failed")
     # Start the AI generation in a background task
     bg_task = asyncio.create_task(run_ai_generation())
     await stream_registry.set_task_asyncio_task(task_id, bg_task)
-    setup_time = (time.perf_counter() - stream_start_time) * 1000
-    logger.info(
-        f"[TIMING] Background task started, setup={setup_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
-    )
     # SSE endpoint that subscribes to the task's stream
     async def event_generator() -> AsyncGenerator[str, None]:
-        import time as time_module
-        event_gen_start = time_module.perf_counter()
-        logger.info(
-            f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
-            f"user={user_id}",
-            extra={"json_fields": log_meta},
-        )
         subscriber_queue = None
-        first_chunk_yielded = False
-        chunks_yielded = 0
         try:
             # Subscribe to the task stream (this replays existing messages + live updates)
             subscriber_queue = await stream_registry.subscribe_to_task(
@@ -442,70 +328,22 @@
                 return
             # Read from the subscriber queue and yield to SSE
-            logger.info(
-                "[TIMING] Starting to read from subscriber_queue",
-                extra={"json_fields": log_meta},
-            )
             while True:
                 try:
                     chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
-                    chunks_yielded += 1
-                    if not first_chunk_yielded:
-                        first_chunk_yielded = True
-                        elapsed = time_module.perf_counter() - event_gen_start
-                        logger.info(
-                            f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
-                            f"type={type(chunk).__name__}",
-                            extra={
-                                "json_fields": {
-                                    **log_meta,
-                                    "chunk_type": type(chunk).__name__,
-                                    "elapsed_ms": elapsed * 1000,
-                                }
-                            },
-                        )
                     yield chunk.to_sse()
                     # Check for finish signal
                     if isinstance(chunk, StreamFinish):
-                        total_time = time_module.perf_counter() - event_gen_start
-                        logger.info(
-                            f"[TIMING] StreamFinish received in {total_time:.2f}s; "
-                            f"n_chunks={chunks_yielded}",
-                            extra={
-                                "json_fields": {
-                                    **log_meta,
-                                    "chunks_yielded": chunks_yielded,
-                                    "total_time_ms": total_time * 1000,
-                                }
-                            },
-                        )
                         break
                 except asyncio.TimeoutError:
+                    # Send heartbeat to keep connection alive
                     yield StreamHeartbeat().to_sse()
         except GeneratorExit:
-            logger.info(
-                f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "chunks_yielded": chunks_yielded,
-                        "reason": "client_disconnect",
-                    }
-                },
-            )
             pass  # Client disconnected - background task continues
         except Exception as e:
-            elapsed = (time_module.perf_counter() - event_gen_start) * 1000
-            logger.error(
-                f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
-                extra={
-                    "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
-                },
-            )
+            logger.error(f"Error in SSE stream for task {task_id}: {e}")
         finally:
             # Unsubscribe when client disconnects or stream ends to prevent resource leak
             if subscriber_queue is not None:
@@ -519,18 +357,6 @@
                         exc_info=True,
                     )
             # AI SDK protocol termination - always yield even if unsubscribe fails
-            total_time = time_module.perf_counter() - event_gen_start
-            logger.info(
-                f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
-                f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "total_time_ms": total_time * 1000,
-                        "chunks_yielded": chunks_yielded,
-                    }
-                },
-            )
             yield "data: [DONE]\n\n"
     return StreamingResponse(
@@ -548,90 +374,63 @@
 @router.get(
     "/sessions/{session_id}/stream",
 )
-async def resume_session_stream(
+async def stream_chat_get(
     session_id: str,
+    message: Annotated[str, Query(min_length=1, max_length=10000)],
     user_id: str | None = Depends(auth.get_user_id),
+    is_user_message: bool = Query(default=True),
 ):
     """
-    Resume an active stream for a session.
-    Called by the AI SDK's ``useChat(resume: true)`` on page load.
-    Checks for an active (in-progress) task on the session and either replays
-    the full SSE stream or returns 204 No Content if nothing is running.
+    Stream chat responses for a session (GET - legacy endpoint).
+    Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
+    - Text fragments as they are generated
+    - Tool call UI elements (if invoked)
+    - Tool execution results
     Args:
-        session_id: The chat session identifier.
+        session_id: The chat session identifier to associate with the streamed messages.
+        message: The user's new message to process.
         user_id: Optional authenticated user ID.
+        is_user_message: Whether the message is a user message.
     Returns:
-        StreamingResponse (SSE) when an active stream exists,
-        or 204 No Content when there is nothing to resume.
+        StreamingResponse: SSE-formatted response chunks.
     """
-    import asyncio
-    active_task, _last_id = await stream_registry.get_active_task_for_session(
-        session_id, user_id
-    )
-    if not active_task:
-        return Response(status_code=204)
-    subscriber_queue = await stream_registry.subscribe_to_task(
-        task_id=active_task.task_id,
-        user_id=user_id,
-        last_message_id="0-0",  # Full replay so useChat rebuilds the message
-    )
-    if subscriber_queue is None:
-        return Response(status_code=204)
+    session = await _validate_and_get_session(session_id, user_id)
     async def event_generator() -> AsyncGenerator[str, None]:
         chunk_count = 0
         first_chunk_type: str | None = None
-        try:
-            while True:
-                try:
-                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
-                    if chunk_count < 3:
-                        logger.info(
-                            "Resume stream chunk",
-                            extra={
-                                "session_id": session_id,
-                                "chunk_type": str(chunk.type),
-                            },
-                        )
-                    if not first_chunk_type:
-                        first_chunk_type = str(chunk.type)
-                    chunk_count += 1
-                    yield chunk.to_sse()
-                    if isinstance(chunk, StreamFinish):
-                        break
-                except asyncio.TimeoutError:
-                    yield StreamHeartbeat().to_sse()
-        except GeneratorExit:
-            pass
-        except Exception as e:
-            logger.error(f"Error in resume stream for session {session_id}: {e}")
-        finally:
-            try:
-                await stream_registry.unsubscribe_from_task(
-                    active_task.task_id, subscriber_queue
-                )
-            except Exception as unsub_err:
-                logger.error(
-                    f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
-                    exc_info=True,
-                )
-        logger.info(
-            "Resume stream completed",
-            extra={
-                "session_id": session_id,
-                "n_chunks": chunk_count,
-                "first_chunk_type": first_chunk_type,
-            },
-        )
+        async for chunk in chat_service.stream_chat_completion(
+            session_id,
+            message,
+            is_user_message=is_user_message,
+            user_id=user_id,
+            session=session,  # Pass pre-fetched session to avoid double-fetch
+        ):
+            if chunk_count < 3:
+                logger.info(
+                    "Chat stream chunk",
+                    extra={
+                        "session_id": session_id,
+                        "chunk_type": str(chunk.type),
+                    },
+                )
+            if not first_chunk_type:
+                first_chunk_type = str(chunk.type)
+            chunk_count += 1
+            yield chunk.to_sse()
+        logger.info(
+            "Chat stream completed",
+            extra={
+                "session_id": session_id,
+                "chunk_count": chunk_count,
+                "first_chunk_type": first_chunk_type,
+            },
+        )
+        # AI SDK protocol termination
         yield "data: [DONE]\n\n"
     return StreamingResponse(
         event_generator(),
@@ -639,8 +438,8 @@
         headers={
             "Cache-Control": "no-cache",
             "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",
-            "x-vercel-ai-ui-message-stream": "v1",
+            "X-Accel-Buffering": "no",  # Disable nginx buffering
+            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
         },
     )
@@ -952,42 +751,3 @@
         "service": "chat",
         "version": "0.1.0",
     }
-# ========== Schema Export (for OpenAPI / Orval codegen) ==========
-ToolResponseUnion = (
-    AgentsFoundResponse
-    | NoResultsResponse
-    | AgentDetailsResponse
-    | SetupRequirementsResponse
-    | ExecutionStartedResponse
-    | NeedLoginResponse
-    | ErrorResponse
-    | InputValidationErrorResponse
-    | AgentOutputResponse
-    | UnderstandingUpdatedResponse
-    | AgentPreviewResponse
-    | AgentSavedResponse
-    | ClarificationNeededResponse
-    | BlockListResponse
-    | BlockOutputResponse
-    | DocSearchResultsResponse
-    | DocPageResponse
-    | OperationStartedResponse
-    | OperationPendingResponse
-    | OperationInProgressResponse
-)
-@router.get(
-    "/schema/tool-responses",
-    response_model=ToolResponseUnion,
-    include_in_schema=True,
-    summary="[Dummy] Tool response type export for codegen",
-    description="This endpoint is not meant to be called. It exists solely to "
-    "expose tool response models in the OpenAPI schema for frontend codegen.",
-)
-async def _tool_response_schema() -> ToolResponseUnion:  # type: ignore[return]
-    """Never called at runtime. Exists only so Orval generates TS types."""
-    raise HTTPException(status_code=501, detail="Schema-only endpoint")

View File

@@ -33,7 +33,7 @@ from backend.data.understanding import (
     get_business_understanding,
 )
 from backend.util.exceptions import NotFoundError
-from backend.util.settings import AppEnvironment, Settings
+from backend.util.settings import Settings
 from . import db as chat_db
 from . import stream_registry
@@ -52,10 +52,8 @@ from .response_model import (
     StreamBaseResponse,
     StreamError,
     StreamFinish,
-    StreamFinishStep,
     StreamHeartbeat,
     StreamStart,
-    StreamStartStep,
     StreamTextDelta,
     StreamTextEnd,
     StreamTextStart,
@@ -224,18 +222,8 @@ async def _get_system_prompt_template(context: str) -> str:
     try:
         # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
         # Use asyncio.to_thread to avoid blocking the event loop
-        # In non-production environments, fetch the latest prompt version
-        # instead of the production-labeled version for easier testing
-        label = (
-            None
-            if settings.config.app_env == AppEnvironment.PRODUCTION
-            else "latest"
-        )
         prompt = await asyncio.to_thread(
-            langfuse.get_prompt,
-            config.langfuse_prompt_name,
-            label=label,
-            cache_ttl_seconds=0,
+            langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
         )
         return prompt.compile(users_information=context)
     except Exception as e:
@@ -353,10 +341,6 @@ async def stream_chat_completion(
     retry_count: int = 0,
     session: ChatSession | None = None,
     context: dict[str, str] | None = None,  # {url: str, content: str}
-    _continuation_message_id: (
-        str | None
-    ) = None,  # Internal: reuse message ID for tool call continuations
-    _task_id: str | None = None,  # Internal: task ID for SSE reconnection support
 ) -> AsyncGenerator[StreamBaseResponse, None]:
     """Main entry point for streaming chat completions with database handling.
@@ -377,45 +361,21 @@
         ValueError: If max_context_messages is exceeded
     """
-    completion_start = time.monotonic()
-    # Build log metadata for structured logging
-    log_meta = {"component": "ChatService", "session_id": session_id}
-    if user_id:
-        log_meta["user_id"] = user_id
     logger.info(
-        f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
-        f"message_len={len(message) if message else 0}, is_user={is_user_message}",
-        extra={
-            "json_fields": {
-                **log_meta,
-                "message_len": len(message) if message else 0,
-                "is_user_message": is_user_message,
-            }
-        },
+        f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
     )
     # Only fetch from Redis if session not provided (initial call)
     if session is None:
-        fetch_start = time.monotonic()
         session = await get_chat_session(session_id, user_id)
-        fetch_time = (time.monotonic() - fetch_start) * 1000
         logger.info(
-            f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
-            f"n_messages={len(session.messages) if session else 0}",
-            extra={
-                "json_fields": {
-                    **log_meta,
-                    "duration_ms": fetch_time,
-                    "n_messages": len(session.messages) if session else 0,
-                }
-            },
+            f"Fetched session from Redis: {session.session_id if session else 'None'}, "
+            f"message_count={len(session.messages) if session else 0}"
         )
     else:
         logger.info(
-            f"[TIMING] Using provided session, messages={len(session.messages)}",
-            extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
+            f"Using provided session object: {session.session_id}, "
+            f"message_count={len(session.messages)}"
         )
     if not session:
@@ -436,25 +396,17 @@
     # Track user message in PostHog
     if is_user_message:
-        posthog_start = time.monotonic()
         track_user_message(
             user_id=user_id,
             session_id=session_id,
             message_length=len(message),
         )
-        posthog_time = (time.monotonic() - posthog_start) * 1000
-        logger.info(
-            f"[TIMING] track_user_message took {posthog_time:.1f}ms",
-            extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
-        )
-    upsert_start = time.monotonic()
-    session = await upsert_chat_session(session)
-    upsert_time = (time.monotonic() - upsert_start) * 1000
     logger.info(
-        f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
+        f"Upserting session: {session.session_id} with user id {session.user_id}, "
+        f"message_count={len(session.messages)}"
     )
+    session = await upsert_chat_session(session)
     assert session, "Session not found"
     # Generate title for new sessions on first user message (non-blocking)
@@ -492,13 +444,7 @@
     asyncio.create_task(_update_title())
     # Build system prompt with business understanding
-    prompt_start = time.monotonic()
     system_prompt, understanding = await _build_system_prompt(user_id)
-    prompt_time = (time.monotonic() - prompt_start) * 1000
-    logger.info(
-        f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
-    )
     # Initialize variables for streaming
     assistant_response = ChatMessage(
@@ -523,27 +469,13 @@
     # Generate unique IDs for AI SDK protocol
     import uuid as uuid_module
-    is_continuation = _continuation_message_id is not None
-    message_id = _continuation_message_id or str(uuid_module.uuid4())
+    message_id = str(uuid_module.uuid4())
     text_block_id = str(uuid_module.uuid4())
-    # Only yield message start for the initial call, not for continuations.
-    setup_time = (time.monotonic() - completion_start) * 1000
-    logger.info(
-        f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
-    )
-    if not is_continuation:
-        yield StreamStart(messageId=message_id, taskId=_task_id)
-    # Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
-    yield StreamStartStep()
+    # Yield message start
+    yield StreamStart(messageId=message_id)
     try:
-        logger.info(
-            "[TIMING] Calling _stream_chat_chunks",
-            extra={"json_fields": log_meta},
-        )
         async for chunk in _stream_chat_chunks(
             session=session,
             tools=tools,
@@ -643,10 +575,6 @@
                 )
                 yield chunk
             elif isinstance(chunk, StreamFinish):
-                if has_done_tool_call:
-                    # Tool calls happened — close the step but don't send message-level finish.
-                    # The continuation will open a new step, and finish will come at the end.
-                    yield StreamFinishStep()
                 if not has_done_tool_call:
                     # Emit text-end before finish if we received text but haven't closed it
                     if has_received_text and not text_streaming_ended:
@@ -678,8 +606,6 @@
                     has_saved_assistant_message = True
                     has_yielded_end = True
-                    # Emit finish-step before finish (resets AI SDK text/reasoning state)
-                    yield StreamFinishStep()
                     yield chunk
             elif isinstance(chunk, StreamError):
                 has_yielded_error = True
@@ -692,9 +618,6 @@
                         total_tokens=chunk.totalTokens,
                     )
                 )
-            elif isinstance(chunk, StreamHeartbeat):
-                # Pass through heartbeat to keep SSE connection alive
-                yield chunk
             else:
                 logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
@@ -729,10 +652,6 @@
             logger.info(
                 f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
             )
-            # Close the current step before retrying so the recursive call's
-            # StreamStartStep doesn't produce unbalanced step events.
-            if not has_yielded_end:
-                yield StreamFinishStep()
             should_retry = True
         else:
             # Non-retryable error or max retries exceeded
@@ -768,7 +687,6 @@
             error_response = StreamError(errorText=error_message)
             yield error_response
             if not has_yielded_end:
-                yield StreamFinishStep()
                 yield StreamFinish()
             return
@@ -783,8 +701,6 @@
             retry_count=retry_count + 1,
             session=session,
             context=context,
-            _continuation_message_id=message_id,  # Reuse message ID since start was already sent
-            _task_id=_task_id,
         ):
             yield chunk
         return  # Exit after retry to avoid double-saving in finally block
@@ -854,8 +770,6 @@
             session=session,  # Pass session object to avoid Redis refetch
             context=context,
             tool_call_response=str(tool_response_messages),
-            _continuation_message_id=message_id,  # Reuse message ID to avoid duplicates
-            _task_id=_task_id,
         ):
             yield chunk
@@ -966,21 +880,9 @@ async def _stream_chat_chunks(
         SSE formatted JSON response objects
     """
-    import time as time_module
-    stream_chunks_start = time_module.perf_counter()
     model = config.model
-    # Build log metadata for structured logging
-    log_meta = {"component": "ChatService", "session_id": session.session_id}
-    if session.user_id:
-        log_meta["user_id"] = session.user_id
-    logger.info(
-        f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
-        f"user={session.user_id}, n_messages={len(session.messages)}",
-        extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
-    )
+    logger.info("Starting pure chat stream")
     messages = session.to_openai_messages()
     if system_prompt:
@@ -991,18 +893,12 @@
         messages = [system_message] + messages
     # Apply context window management
-    context_start = time_module.perf_counter()
     context_result = await _manage_context_window(
         messages=messages,
         model=model,
         api_key=config.api_key,
         base_url=config.base_url,
     )
-    context_time = (time_module.perf_counter() - context_start) * 1000
-    logger.info(
-        f"[TIMING] _manage_context_window took {context_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "duration_ms": context_time}},
-    )
     if context_result.error:
         if "System prompt dropped" in context_result.error:
@@ -1037,19 +933,9 @@
     while retry_count <= MAX_RETRIES:
         try:
-            elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
-            retry_info = (
-                f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
-            )
             logger.info(
-                f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "elapsed_ms": elapsed,
-                        "retry_count": retry_count,
-                    }
-                },
+                f"Creating OpenAI chat completion stream..."
+                f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
             )
             # Build extra_body for OpenRouter tracing and PostHog analytics
@@ -1066,7 +952,6 @@ async def _stream_chat_chunks(
:128 :128
] # OpenRouter limit ] # OpenRouter limit
api_call_start = time_module.perf_counter()
stream = await client.chat.completions.create( stream = await client.chat.completions.create(
model=model, model=model,
messages=cast(list[ChatCompletionMessageParam], messages), messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1076,11 +961,6 @@ async def _stream_chat_chunks(
stream_options=ChatCompletionStreamOptionsParam(include_usage=True), stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
extra_body=extra_body, extra_body=extra_body,
) )
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
logger.info(
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
)
# Variables to accumulate tool calls # Variables to accumulate tool calls
tool_calls: list[dict[str, Any]] = [] tool_calls: list[dict[str, Any]] = []
@@ -1091,13 +971,10 @@
             # Track if we've started the text block
             text_started = False
-            first_content_chunk = True
-            chunk_count = 0
             # Process the stream
             chunk: ChatCompletionChunk
             async for chunk in stream:
-                chunk_count += 1
                 if chunk.usage:
                     yield StreamUsage(
                         promptTokens=chunk.usage.prompt_tokens,
@@ -1120,23 +997,6 @@
                     if not text_started and text_block_id:
                         yield StreamTextStart(id=text_block_id)
                         text_started = True
-                    # Log timing for first content chunk
-                    if first_content_chunk:
-                        first_content_chunk = False
-                        ttfc = (
-                            time_module.perf_counter() - api_call_start
-                        ) * 1000
-                        logger.info(
-                            f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
-                            f"(since API call), n_chunks={chunk_count}",
-                            extra={
-                                "json_fields": {
-                                    **log_meta,
-                                    "time_to_first_chunk_ms": ttfc,
-                                    "n_chunks": chunk_count,
-                                }
-                            },
-                        )
                     # Stream the text delta
                     text_response = StreamTextDelta(
                         id=text_block_id or "",
@@ -1193,21 +1053,7 @@ async def _stream_chat_chunks(
toolName=tool_calls[idx]["function"]["name"], toolName=tool_calls[idx]["function"]["name"],
) )
emitted_start_for_idx.add(idx) emitted_start_for_idx.add(idx)
stream_duration = time_module.perf_counter() - api_call_start logger.info(f"Stream complete. Finish reason: {finish_reason}")
logger.info(
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
f"duration={stream_duration:.2f}s, "
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
extra={
"json_fields": {
**log_meta,
"stream_duration_ms": stream_duration * 1000,
"finish_reason": finish_reason,
"n_chunks": chunk_count,
"n_tool_calls": len(tool_calls),
}
},
)
# Yield all accumulated tool calls after the stream is complete # Yield all accumulated tool calls after the stream is complete
# This ensures all tool call arguments have been fully received # This ensures all tool call arguments have been fully received
@@ -1227,12 +1073,6 @@ async def _stream_chat_chunks(
# Re-raise to trigger retry logic in the parent function # Re-raise to trigger retry logic in the parent function
raise raise
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
yield StreamFinish() yield StreamFinish()
return return
except Exception as e: except Exception as e:
@@ -1712,7 +1552,6 @@ async def _execute_long_running_tool_with_streaming(
task_id, task_id,
StreamError(errorText=str(e)), StreamError(errorText=str(e)),
) )
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish()) await stream_registry.publish_chunk(task_id, StreamFinish())
await _update_pending_operation( await _update_pending_operation(
@@ -1970,7 +1809,6 @@ async def _generate_llm_continuation_with_streaming(
# Publish start event # Publish start event
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id)) await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamStartStep())
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id)) await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
# Stream the response # Stream the response
@@ -1994,7 +1832,6 @@ async def _generate_llm_continuation_with_streaming(
# Publish end events # Publish end events
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id)) await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
await stream_registry.publish_chunk(task_id, StreamFinishStep())
if assistant_content: if assistant_content:
# Reload session from DB to avoid race condition with user messages # Reload session from DB to avoid race condition with user messages
@@ -2036,5 +1873,4 @@ async def _generate_llm_continuation_with_streaming(
task_id, task_id,
StreamError(errorText=f"Failed to generate response: {e}"), StreamError(errorText=f"Failed to generate response: {e}"),
) )
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish()) await stream_registry.publish_chunk(task_id, StreamFinish())
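For reference, the removed [TIMING] lines all follow one pattern: snapshot time.perf_counter() before the awaited call, convert the delta to milliseconds, and attach structured fields through the logger's extra dict. A minimal sketch of that pattern (the json_fields key is the convention used by the code above; the helper name is ours, not the repo's):

import logging
import time

logger = logging.getLogger(__name__)

async def timed(label: str, awaitable, **log_meta):
    """Await `awaitable` and log its duration with structured fields."""
    start = time.perf_counter()
    result = await awaitable
    duration_ms = (time.perf_counter() - start) * 1000
    logger.info(
        f"[TIMING] {label} took {duration_ms:.1f}ms",
        extra={"json_fields": {**log_meta, "duration_ms": duration_ms}},
    )
    return result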

View File

@@ -104,24 +104,6 @@ async def create_task(
Returns: Returns:
The created ActiveTask instance (metadata only) The created ActiveTask instance (metadata only)
""" """
import time
start_time = time.perf_counter()
# Build log metadata for structured logging
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"session_id": session_id,
}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
task = ActiveTask( task = ActiveTask(
task_id=task_id, task_id=task_id,
session_id=session_id, session_id=session_id,
@@ -132,18 +114,10 @@ async def create_task(
) )
# Store metadata in Redis # Store metadata in Redis
redis_start = time.perf_counter()
redis = await get_redis_async() redis = await get_redis_async()
redis_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
)
meta_key = _get_task_meta_key(task_id) meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id) op_key = _get_operation_mapping_key(operation_id)
hset_start = time.perf_counter()
await redis.hset( # type: ignore[misc] await redis.hset( # type: ignore[misc]
meta_key, meta_key,
mapping={ mapping={
@@ -157,22 +131,12 @@ async def create_task(
"created_at": task.created_at.isoformat(), "created_at": task.created_at.isoformat(),
}, },
) )
hset_time = (time.perf_counter() - hset_start) * 1000
logger.info(
f"[TIMING] redis.hset took {hset_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
)
await redis.expire(meta_key, config.stream_ttl) await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups # Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl) await redis.set(op_key, task_id, ex=config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000 logger.debug(f"Created task {task_id} for session {session_id}")
logger.info(
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
return task return task
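The surviving create_task body boils down to three Redis writes: a hash for the task metadata, a TTL on that hash, and a string mapping operation_id back to task_id for webhook lookups. A condensed sketch using redis-py's asyncio client (the key layouts are illustrative; the real prefixes come from _get_task_meta_key and _get_operation_mapping_key):

# Sketch only; `redis` is a redis.asyncio.Redis with decode_responses=True.
meta_key = f"task:{task_id}:meta"      # assumed shape, not the real prefix
op_key = f"task:op:{operation_id}"     # assumed shape, not the real prefix
await redis.hset(meta_key, mapping={"status": "running", "session_id": session_id})
await redis.expire(meta_key, config.stream_ttl)   # metadata and stream share one TTL
await redis.set(op_key, task_id, ex=config.stream_ttl)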
@@ -192,60 +156,26 @@ async def publish_chunk(
Returns: Returns:
The Redis Stream message ID The Redis Stream message ID
""" """
import time
start_time = time.perf_counter()
chunk_type = type(chunk).__name__
chunk_json = chunk.model_dump_json() chunk_json = chunk.model_dump_json()
message_id = "0-0" message_id = "0-0"
# Build log metadata
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"chunk_type": chunk_type,
}
try: try:
redis = await get_redis_async() redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id) stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery # Write to Redis Stream for persistence and real-time delivery
xadd_start = time.perf_counter()
raw_id = await redis.xadd( raw_id = await redis.xadd(
stream_key, stream_key,
{"data": chunk_json}, {"data": chunk_json},
maxlen=config.stream_max_length, maxlen=config.stream_max_length,
) )
xadd_time = (time.perf_counter() - xadd_start) * 1000
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode() message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL # Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl) await redis.expire(stream_key, config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
chunk_type
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
or total_time > 50
):
logger.info(
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"xadd_time_ms": xadd_time,
"message_id": message_id,
}
},
)
except Exception as e: except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000
logger.error( logger.error(
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}", f"Failed to publish chunk for task {task_id}: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
exc_info=True, exc_info=True,
) )
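With the instrumentation gone, publish_chunk's core is a capped stream append plus a TTL refresh; everything removed above was measurement around these two calls. A sketch, assuming a redis.asyncio client and the config names from the diff:

raw_id = await redis.xadd(
    stream_key,
    {"data": chunk.model_dump_json()},   # chunks are serialized Pydantic models
    maxlen=config.stream_max_length,     # cap stream length; oldest entries are trimmed
)
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
await redis.expire(stream_key, config.stream_ttl)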
@@ -270,61 +200,24 @@ async def subscribe_to_task(
An asyncio Queue that will receive stream chunks, or None if task not found An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access or user doesn't have access
""" """
import time
start_time = time.perf_counter()
# Build log metadata
log_meta = {"component": "StreamRegistry", "task_id": task_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
)
redis_start = time.perf_counter()
redis = await get_redis_async() redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id) meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
if not meta: if not meta:
elapsed = (time.perf_counter() - start_time) * 1000 logger.debug(f"Task {task_id} not found in Redis")
logger.info(
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"reason": "task_not_found",
}
},
)
return None return None
# Note: Redis client uses decode_responses=True, so keys are strings # Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "") task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None task_user_id = meta.get("user_id", "") or None
log_meta["session_id"] = meta.get("session_id", "")
# Validate ownership - if task has an owner, requester must match # Validate ownership - if task has an owner, requester must match
if task_user_id: if task_user_id:
if user_id != task_user_id: if user_id != task_user_id:
logger.warning( logger.warning(
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}", f"User {user_id} denied access to task {task_id} "
extra={ f"owned by {task_user_id}"
"json_fields": {
**log_meta,
"task_owner": task_user_id,
"reason": "access_denied",
}
},
) )
return None return None
@@ -332,19 +225,7 @@ async def subscribe_to_task(
stream_key = _get_task_stream_key(task_id) stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream # Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
extra={
"json_fields": {
**log_meta,
"duration_ms": xread_time,
"task_status": task_status,
}
},
)
replayed_count = 0 replayed_count = 0
replay_last_id = last_message_id replay_last_id = last_message_id
@@ -363,48 +244,19 @@ async def subscribe_to_task(
except Exception as e: except Exception as e:
logger.warning(f"Failed to replay message: {e}") logger.warning(f"Failed to replay message: {e}")
logger.info( logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
extra={
"json_fields": {
**log_meta,
"n_messages_replayed": replayed_count,
"replay_last_id": replay_last_id,
}
},
)
# Step 2: If task is still running, start stream listener for live updates # Step 2: If task is still running, start stream listener for live updates
if task_status == "running": if task_status == "running":
logger.info(
"[TIMING] Task still running, starting _stream_listener",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
listener_task = asyncio.create_task( listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta) _stream_listener(task_id, subscriber_queue, replay_last_id)
) )
# Track listener task for cleanup on unsubscribe # Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task) _listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else: else:
# Task is completed/failed - add finish marker # Task is completed/failed - add finish marker
logger.info(
f"[TIMING] Task already {task_status}, adding StreamFinish",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
await subscriber_queue.put(StreamFinish()) await subscriber_queue.put(StreamFinish())
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
f"n_messages_replayed={replayed_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"n_messages_replayed": replayed_count,
}
},
)
return subscriber_queue return subscriber_queue
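Replay-then-listen is the pattern here: one XREAD starting from the subscriber's last delivered ID replays history, and only a still-running task gets a live listener attached. A toy decode loop for the replay step (field layout as written by publish_chunk; decode_responses=True means keys and values arrive as str):

messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
for _stream, entries in messages:
    for message_id, fields in entries:
        chunk_json = fields["data"]
        # _reconstruct_chunk(json.loads(chunk_json)) turns this back into a Stream* model
        replay_last_id = message_id   # resume live reads from here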
@@ -412,7 +264,6 @@ async def _stream_listener(
task_id: str, task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse], subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str, last_replayed_id: str,
log_meta: dict | None = None,
) -> None: ) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD. """Listen to Redis Stream for new messages using blocking XREAD.
@@ -423,27 +274,10 @@ async def _stream_listener(
task_id: Task ID to listen for task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here) last_replayed_id: Last message ID from replay (continue from here)
log_meta: Structured logging metadata
""" """
import time
start_time = time.perf_counter()
# Use provided log_meta or build minimal one
if log_meta is None:
log_meta = {"component": "StreamRegistry", "task_id": task_id}
logger.info(
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
)
queue_id = id(subscriber_queue) queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints # Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id last_delivered_id = last_replayed_id
messages_delivered = 0
first_message_time = None
xread_count = 0
try: try:
redis = await get_redis_async() redis = await get_redis_async()
@@ -453,39 +287,9 @@ async def _stream_listener(
while True: while True:
# Block for up to 30 seconds waiting for new messages # Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running # This allows periodic checking if task is still running
xread_start = time.perf_counter()
xread_count += 1
messages = await redis.xread( messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100 {stream_key: current_id}, block=30000, count=100
) )
xread_time = (time.perf_counter() - xread_start) * 1000
if messages:
msg_count = sum(len(msgs) for _, msgs in messages)
logger.info(
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"n_messages": msg_count,
"duration_ms": xread_time,
}
},
)
elif xread_time > 1000:
# Only log timeouts (30s blocking)
logger.info(
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"duration_ms": xread_time,
"reason": "timeout",
}
},
)
if not messages: if not messages:
# Timeout - check if task is still running # Timeout - check if task is still running
@@ -522,30 +326,10 @@ async def _stream_listener(
) )
# Update last delivered ID on successful delivery # Update last delivered ID on successful delivery
last_delivered_id = current_id last_delivered_id = current_id
messages_delivered += 1
if first_message_time is None:
first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000
logger.info(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"chunk_type": type(chunk).__name__,
}
},
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning( logger.warning(
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s", f"Subscriber queue full for task {task_id}, "
extra={ f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
"json_fields": {
**log_meta,
"timeout_s": QUEUE_PUT_TIMEOUT,
"reason": "queue_full",
}
},
) )
# Send overflow error with recovery info # Send overflow error with recovery info
try: try:
@@ -567,44 +351,15 @@ async def _stream_listener(
# Stop listening on finish # Stop listening on finish
if isinstance(chunk, StreamFinish): if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
}
},
)
return return
except Exception as e: except Exception as e:
logger.warning( logger.warning(f"Error processing stream message: {e}")
f"Error processing stream message: {e}",
extra={"json_fields": {**log_meta, "error": str(e)}},
)
except asyncio.CancelledError: except asyncio.CancelledError:
elapsed = (time.perf_counter() - start_time) * 1000 logger.debug(f"Stream listener cancelled for task {task_id}")
logger.info(
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"messages_delivered": messages_delivered,
"reason": "cancelled",
}
},
)
raise # Re-raise to propagate cancellation raise # Re-raise to propagate cancellation
except Exception as e: except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000 logger.error(f"Stream listener error for task {task_id}: {e}")
logger.error(
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
)
# On error, send finish to unblock subscriber # On error, send finish to unblock subscriber
try: try:
await asyncio.wait_for( await asyncio.wait_for(
@@ -613,24 +368,10 @@ async def _stream_listener(
) )
except (asyncio.TimeoutError, asyncio.QueueFull): except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning( logger.warning(
"Could not deliver finish event after error", f"Could not deliver finish event for task {task_id} after error"
extra={"json_fields": log_meta},
) )
finally: finally:
# Clean up listener task mapping on exit # Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
"xread_count": xread_count,
}
},
)
_listener_tasks.pop(queue_id, None) _listener_tasks.pop(queue_id, None)
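The listener's delivery path guards against slow consumers: each queue put is bounded by a timeout so one stuck subscriber cannot stall the XREAD loop for everyone. A sketch of that guard, using the QUEUE_PUT_TIMEOUT constant from the diff:

import asyncio

try:
    await asyncio.wait_for(subscriber_queue.put(chunk), timeout=QUEUE_PUT_TIMEOUT)
except asyncio.TimeoutError:
    # Queue stayed full past the deadline; surface an overflow error
    # with recovery info instead of blocking the listener.
    logger.warning(f"Subscriber queue full for task {task_id}")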
@@ -857,10 +598,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
ResponseType, ResponseType,
StreamError, StreamError,
StreamFinish, StreamFinish,
StreamFinishStep,
StreamHeartbeat, StreamHeartbeat,
StreamStart, StreamStart,
StreamStartStep,
StreamTextDelta, StreamTextDelta,
StreamTextEnd, StreamTextEnd,
StreamTextStart, StreamTextStart,
@@ -874,8 +613,6 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
type_to_class: dict[str, type[StreamBaseResponse]] = { type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart, ResponseType.START.value: StreamStart,
ResponseType.FINISH.value: StreamFinish, ResponseType.FINISH.value: StreamFinish,
ResponseType.START_STEP.value: StreamStartStep,
ResponseType.FINISH_STEP.value: StreamFinishStep,
ResponseType.TEXT_START.value: StreamTextStart, ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta, ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd, ResponseType.TEXT_END.value: StreamTextEnd,

View File

@@ -7,7 +7,15 @@ from typing import Any, NotRequired, TypedDict
from backend.api.features.library import db as library_db from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db from backend.api.features.store import db as store_db
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs from backend.data.graph import (
Graph,
Link,
Node,
create_graph,
get_graph,
get_graph_all_versions,
get_store_listed_graphs,
)
from backend.util.exceptions import DatabaseError, NotFoundError from backend.util.exceptions import DatabaseError, NotFoundError
from .service import ( from .service import (
@@ -20,6 +28,8 @@ from .service import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
class ExecutionSummary(TypedDict): class ExecutionSummary(TypedDict):
"""Summary of a single execution for quality assessment.""" """Summary of a single execution for quality assessment."""
@@ -659,6 +669,45 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
) )
def _reassign_node_ids(graph: Graph) -> None:
"""Reassign all node and link IDs to new UUIDs.
This is needed when creating a new version to avoid unique constraint violations.
"""
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]
for link in graph.links:
link.id = str(uuid.uuid4())
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]
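A dict-based toy of the same remapping idea: every node gets a fresh UUID, and links are rewritten through the old-to-new map so edges stay consistent (plain dicts stand in for the Graph/Node/Link models):

import uuid

nodes = [{"id": "a"}, {"id": "b"}]
links = [{"id": "l1", "source_id": "a", "sink_id": "b"}]

id_map = {n["id"]: str(uuid.uuid4()) for n in nodes}
for n in nodes:
    n["id"] = id_map[n["id"]]
for link in links:
    link["id"] = str(uuid.uuid4())
    link["source_id"] = id_map.get(link["source_id"], link["source_id"])
    link["sink_id"] = id_map.get(link["sink_id"], link["sink_id"])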
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
"""Populate user_id in AgentExecutorBlock nodes.
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
This function fills in the actual user_id so sub-agents run with correct permissions.
Args:
agent_json: Agent JSON dict (modified in place)
user_id: User ID to set
"""
for node in agent_json.get("nodes", []):
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
input_default = node.get("input_default") or {}
if not input_default.get("user_id"):
input_default["user_id"] = user_id
node["input_default"] = input_default
logger.debug(
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
)
async def save_agent_to_library( async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]: ) -> tuple[Graph, Any]:
@@ -672,10 +721,35 @@ async def save_agent_to_library(
Returns: Returns:
Tuple of (created Graph, LibraryAgent) Tuple of (created Graph, LibraryAgent)
""" """
# Populate user_id in AgentExecutorBlock nodes before conversion
_populate_agent_executor_user_ids(agent_json, user_id)
graph = json_to_graph(agent_json) graph = json_to_graph(agent_json)
if is_update: if is_update:
return await library_db.update_graph_in_library(graph, user_id) if graph.id:
return await library_db.create_graph_in_library(graph, user_id) existing_versions = await get_graph_all_versions(graph.id, user_id)
if existing_versions:
latest_version = max(v.version for v in existing_versions)
graph.version = latest_version + 1
_reassign_node_ids(graph)
logger.info(f"Updating agent {graph.id} to version {graph.version}")
else:
graph.id = str(uuid.uuid4())
graph.version = 1
_reassign_node_ids(graph)
logger.info(f"Creating new agent with ID {graph.id}")
created_graph = await create_graph(graph, user_id)
library_agents = await library_db.create_library_agent(
graph=created_graph,
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
)
return created_graph, library_agents[0]
def graph_to_json(graph: Graph) -> dict[str, Any]: def graph_to_json(graph: Graph) -> dict[str, Any]:

View File

@@ -206,9 +206,9 @@ async def search_agents(
] ]
) )
no_results_msg = ( no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs." f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
if source == "marketplace" if source == "marketplace"
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs." else f"No agents matching '{query}' found in your library."
) )
return NoResultsResponse( return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions message=no_results_msg, session_id=session_id, suggestions=suggestions
@@ -224,10 +224,10 @@ async def search_agents(
message = ( message = (
"Now you have found some options for the user to choose from. " "Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id " "You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs." "Please ask the user if they would like to use any of these agents."
if source == "marketplace" if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view an agent at: " else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs." "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
) )
return AgentsFoundResponse( return AgentsFoundResponse(

View File

@@ -3,6 +3,8 @@
import logging import logging
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError from backend.api.features.store.exceptions import AgentNotFoundError
@@ -27,6 +29,23 @@ from .models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CustomizeAgentInput(BaseModel):
"""Input parameters for the customize_agent tool."""
agent_id: str = ""
modifications: str = ""
context: str = ""
save: bool = True
@field_validator("agent_id", "modifications", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
"""Strip whitespace from string fields."""
if isinstance(v, str):
return v.strip()
return v if v is not None else ""
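Because the validator runs in mode="before", whitespace is stripped prior to field validation and None collapses to the empty-string default. A quick illustration:

params = CustomizeAgentInput(
    agent_id="  autogpt/newsletter-writer  ",
    modifications=" make it weekly ",
)
assert params.agent_id == "autogpt/newsletter-writer"
assert params.modifications == "make it weekly"
assert params.save is True  # defaults still apply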
class CustomizeAgentTool(BaseTool): class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language.""" """Tool for customizing marketplace/template agents using natural language."""
@@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Execute the customize_agent tool. """Execute the customize_agent tool.
@@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool):
3. Call customize_template with the modification request 3. Call customize_template with the modification request
4. Preview or save based on the save parameter 4. Preview or save based on the save parameter
""" """
agent_id = kwargs.get("agent_id", "").strip() params = CustomizeAgentInput(**kwargs)
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None session_id = session.session_id if session else None
if not agent_id: if not params.agent_id:
return ErrorResponse( return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').", message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id", error="missing_agent_id",
session_id=session_id, session_id=session_id,
) )
if not modifications: if not params.modifications:
return ErrorResponse( return ErrorResponse(
message="Please describe how you want to customize this agent.", message="Please describe how you want to customize this agent.",
error="missing_modifications", error="missing_modifications",
@@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool):
) )
# Parse agent_id in format "creator/slug" # Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")] parts = params.agent_id.split("/")
if len(parts) != 2 or not parts[0] or not parts[1]: if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse( return ErrorResponse(
message=( message=(
f"Invalid agent ID format: '{agent_id}'. " f"Invalid agent ID format: '{params.agent_id}'. "
"Expected format is 'creator/agent-name' " "Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')." "(e.g., 'autogpt/newsletter-writer')."
), ),
@@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool):
except AgentNotFoundError: except AgentNotFoundError:
return ErrorResponse( return ErrorResponse(
message=( message=(
f"Could not find marketplace agent '{agent_id}'. " f"Could not find marketplace agent '{params.agent_id}'. "
"Please check the agent ID and try again." "Please check the agent ID and try again."
), ),
error="agent_not_found", error="agent_not_found",
session_id=session_id, session_id=session_id,
) )
except Exception as e: except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}") logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
return ErrorResponse( return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.", message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error", error="fetch_error",
@@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool):
if not agent_details.store_listing_version_id: if not agent_details.store_listing_version_id:
return ErrorResponse( return ErrorResponse(
message=( message=(
f"The agent '{agent_id}' does not have an available version. " f"The agent '{params.agent_id}' does not have an available version. "
"Please try a different agent." "Please try a different agent."
), ),
error="no_version_available", error="no_version_available",
@@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool):
graph = await store_db.get_agent(agent_details.store_listing_version_id) graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph) template_agent = graph_to_json(graph)
except Exception as e: except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}") logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
return ErrorResponse( return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.", message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error", error="graph_fetch_error",
@@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool):
try: try:
result = await customize_template( result = await customize_template(
template_agent=template_agent, template_agent=template_agent,
modification_request=modifications, modification_request=params.modifications,
context=context, context=params.context,
) )
except AgentGeneratorNotConfiguredError: except AgentGeneratorNotConfiguredError:
return ErrorResponse( return ErrorResponse(
@@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
except Exception as e: except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}") logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
return ErrorResponse( return ErrorResponse(
message=( message=(
"Failed to customize the agent due to a service error. " "Failed to customize the agent due to a service error. "
@@ -219,55 +235,25 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
# Handle error response # Handle response using match/case for cleaner pattern matching
if isinstance(result, dict) and result.get("type") == "error": return await self._handle_customization_result(
error_msg = result.get("error", "Unknown error") result=result,
error_type = result.get("error_type", "unknown") params=params,
user_message = get_user_message_for_error( agent_details=agent_details,
error_type, user_id=user_id,
operation="customize the agent", session_id=session_id,
llm_parse_message=( )
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle clarifying questions async def _handle_customization_result(
if isinstance(result, dict) and result.get("type") == "clarifying_questions": self,
questions = result.get("questions") or [] result: dict[str, Any],
if not isinstance(questions, list): params: CustomizeAgentInput,
logger.error( agent_details: Any,
f"Unexpected clarifying questions format: {type(questions)}" user_id: str | None,
) session_id: str | None,
questions = [] ) -> ToolResponseBase:
return ClarificationNeededResponse( """Handle the result from customize_template using pattern matching."""
message=( # Ensure result is a dict
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
if not isinstance(result, dict): if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}") logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse( return ErrorResponse(
@@ -276,8 +262,77 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
customized_agent = result result_type = result.get("type")
match result_type:
case "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
case "clarifying_questions":
questions_data = result.get("questions") or []
if not isinstance(questions_data, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions_data)}"
)
questions_data = []
questions = [
ClarifyingQuestion(
question=q.get("question", "") if isinstance(q, dict) else "",
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
example=q.get("example") if isinstance(q, dict) else None,
)
for q in questions_data
if isinstance(q, dict)
]
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=questions,
session_id=session_id,
)
case _:
# Default case: result is the customized agent JSON
return await self._save_or_preview_agent(
customized_agent=result,
params=params,
agent_details=agent_details,
user_id=user_id,
session_id=session_id,
)
async def _save_or_preview_agent(
self,
customized_agent: dict[str, Any],
params: CustomizeAgentInput,
agent_details: Any,
user_id: str | None,
session_id: str | None,
) -> ToolResponseBase:
"""Save or preview the customized agent based on params.save."""
agent_name = customized_agent.get( agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}" "name", f"Customized {agent_details.agent_name}"
) )
@@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool):
node_count = len(nodes) if isinstance(nodes, list) else 0 node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0 link_count = len(links) if isinstance(links, list) else 0
if not save: if not params.save:
return AgentPreviewResponse( return AgentPreviewResponse(
message=( message=(
f"I've customized the agent '{agent_details.agent_name}'. " f"I've customized the agent '{agent_details.agent_name}'. "

View File

@@ -13,32 +13,10 @@ from backend.api.features.chat.tools.models import (
NoResultsResponse, NoResultsResponse,
) )
from backend.api.features.store.hybrid_search import unified_hybrid_search from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import BlockType, get_block from backend.data.block import get_block
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_TARGET_RESULTS = 10
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
# 40 is roughly 2x what filtering currently removes; query speed at page size 10 vs 40 is minimal
_OVERFETCH_PAGE_SIZE = 40
# Block types that only work within graphs and cannot run standalone in CoPilot.
COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
COPILOT_EXCLUDED_BLOCK_IDS = {
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
"3b191d9f-356f-482d-8238-ba04b6d18381",
}
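Condensed, the removed filter is one predicate applied per search hit; a block survives only if it is enabled and on neither exclusion list. A sketch mirroring the inline checks the removed loop performs:

def is_copilot_runnable(block) -> bool:
    # get_block may return None for stale search hits.
    return (
        block is not None
        and not block.disabled
        and block.block_type not in COPILOT_EXCLUDED_BLOCK_TYPES
        and block.id not in COPILOT_EXCLUDED_BLOCK_IDS
    )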
class FindBlockTool(BaseTool): class FindBlockTool(BaseTool):
"""Tool for searching available blocks.""" """Tool for searching available blocks."""
@@ -110,7 +88,7 @@ class FindBlockTool(BaseTool):
query=query, query=query,
content_types=[ContentType.BLOCK], content_types=[ContentType.BLOCK],
page=1, page=1,
page_size=_OVERFETCH_PAGE_SIZE, page_size=10,
) )
if not results: if not results:
@@ -130,90 +108,60 @@ class FindBlockTool(BaseTool):
block = get_block(block_id) block = get_block(block_id)
# Skip disabled blocks # Skip disabled blocks
if not block or block.disabled: if block and not block.disabled:
continue # Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception:
pass
try:
output_schema = block.output_schema.jsonschema()
except Exception:
pass
# Skip blocks excluded from CoPilot (graph-only blocks) # Get categories from block instance
if ( categories = []
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES if hasattr(block, "categories") and block.categories:
or block.id in COPILOT_EXCLUDED_BLOCK_IDS categories = [cat.value for cat in block.categories]
):
continue
# Get input/output schemas # Extract required inputs for easier use
input_schema = {} required_inputs: list[BlockInputFieldInfo] = []
output_schema = {} if input_schema:
try: properties = input_schema.get("properties", {})
input_schema = block.input_schema.jsonschema() required_fields = set(input_schema.get("required", []))
except Exception as e: # Get credential field names to exclude from required inputs
logger.debug( credentials_fields = set(
"Failed to generate input schema for block %s: %s", block.input_schema.get_credentials_fields().keys()
block_id,
e,
)
try:
output_schema = block.output_schema.jsonschema()
except Exception as e:
logger.debug(
"Failed to generate output schema for block %s: %s",
block_id,
e,
)
# Get categories from block instance
categories = []
if hasattr(block, "categories") and block.categories:
categories = [cat.value for cat in block.categories]
# Extract required inputs for easier use
required_inputs: list[BlockInputFieldInfo] = []
if input_schema:
properties = input_schema.get("properties", {})
required_fields = set(input_schema.get("required", []))
# Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
) )
blocks.append( for field_name, field_schema in properties.items():
BlockInfoSummary( # Skip credential fields - they're handled separately
id=block_id, if field_name in credentials_fields:
name=block.name, continue
description=block.description or "",
categories=categories, required_inputs.append(
input_schema=input_schema, BlockInputFieldInfo(
output_schema=output_schema, name=field_name,
required_inputs=required_inputs, type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
)
) )
)
if len(blocks) >= _TARGET_RESULTS:
break
if blocks and len(blocks) < _TARGET_RESULTS:
logger.debug(
"find_block returned %d/%d results for query '%s' "
"(filtered %d excluded/disabled blocks)",
len(blocks),
_TARGET_RESULTS,
query,
len(results) - len(blocks),
)
if not blocks: if not blocks:
return NoResultsResponse( return NoResultsResponse(

View File

@@ -1,139 +0,0 @@
"""Tests for block filtering in FindBlockTool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
FindBlockTool,
)
from backend.api.features.chat.tools.models import BlockListResponse
from backend.data.block import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-find-block"
def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.description = f"{name} description"
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields.return_value = {}
mock.output_schema = MagicMock()
mock.output_schema.jsonschema.return_value = {}
mock.categories = []
return mock
class TestFindBlockFiltering:
"""Tests for block filtering in FindBlockTool."""
def test_excluded_block_types_contains_expected_types(self):
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
def test_excluded_block_ids_contains_smart_decision_maker(self):
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_filtered_from_results(self):
"""Verify blocks with excluded BlockTypes are filtered from search results."""
session = make_session(user_id=_TEST_USER_ID)
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
search_results = [
{"content_id": "input-block-id", "score": 0.9},
{"content_id": "standard-block-id", "score": 0.8},
]
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
standard_block = make_mock_block(
"standard-block-id", "HTTP Request", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
"input-block-id": input_block,
"standard-block-id": standard_block,
}.get(block_id)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="test"
)
# Should only return the standard block, not the INPUT block
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "standard-block-id"
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_filtered_from_results(self):
"""Verify SmartDecisionMakerBlock is filtered from search results."""
session = make_session(user_id=_TEST_USER_ID)
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
search_results = [
{"content_id": smart_decision_id, "score": 0.9},
{"content_id": "normal-block-id", "score": 0.8},
]
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)
normal_block = make_mock_block(
"normal-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
smart_decision_id: smart_block,
"normal-block-id": normal_block,
}.get(block_id)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="decision"
)
# Should only return normal block, not SmartDecisionMakerBlock
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "normal-block-id"

View File

@@ -1,29 +0,0 @@
"""Shared helpers for chat tools."""
from typing import Any
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
"""Extract input field info from JSON schema."""
if not isinstance(input_schema, dict):
return []
exclude = exclude_fields or set()
properties = input_schema.get("properties", {})
required = set(input_schema.get("required", []))
return [
{
"name": name,
"title": schema.get("title", name),
"type": schema.get("type", "string"),
"description": schema.get("description", ""),
"required": name in required,
"default": schema.get("default"),
}
for name, schema in properties.items()
if name not in exclude
]
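For the record, the deleted helper was pure schema plumbing; given a JSON schema it returns one dict per non-excluded property. An example of what it produced:

schema = {
    "properties": {
        "url": {"type": "string", "description": "Target URL"},
        "credentials": {"type": "object"},
    },
    "required": ["url"],
}
get_inputs_from_schema(schema, exclude_fields={"credentials"})
# -> [{"name": "url", "title": "url", "type": "string",
#      "description": "Target URL", "required": True, "default": None}]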

View File

@@ -24,7 +24,6 @@ from backend.util.timezone_utils import (
) )
from .base import BaseTool from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import ( from .models import (
AgentDetails, AgentDetails,
AgentDetailsResponse, AgentDetailsResponse,
@@ -262,7 +261,7 @@ class RunAgentTool(BaseTool):
), ),
requirements={ requirements={
"credentials": requirements_creds_list, "credentials": requirements_creds_list,
"inputs": get_inputs_from_schema(graph.input_schema), "inputs": self._get_inputs_list(graph.input_schema),
"execution_modes": self._get_execution_modes(graph), "execution_modes": self._get_execution_modes(graph),
}, },
), ),
@@ -370,6 +369,22 @@ class RunAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema."""
inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]: def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph.""" """Get available execution modes for the graph."""
trigger_info = graph.trigger_setup_info trigger_info = graph.trigger_setup_info
@@ -383,7 +398,7 @@ class RunAgentTool(BaseTool):
suffix: str, suffix: str,
) -> str: ) -> str:
"""Build a message describing available inputs for an agent.""" """Build a message describing available inputs for an agent."""
inputs_list = get_inputs_from_schema(graph.input_schema) inputs_list = self._get_inputs_list(graph.input_schema)
required_names = [i["name"] for i in inputs_list if i["required"]] required_names = [i["name"] for i in inputs_list if i["required"]]
optional_names = [i["name"] for i in inputs_list if not i["required"]] optional_names = [i["name"] for i in inputs_list if not i["required"]]

View File

@@ -8,19 +8,14 @@ from typing import Any
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.find_block import ( from backend.data.block import get_block
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.data.block import AnyBlockSchema, get_block
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError from backend.util.exceptions import BlockError
from .base import BaseTool from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import ( from .models import (
BlockOutputResponse, BlockOutputResponse,
ErrorResponse, ErrorResponse,
@@ -29,10 +24,7 @@ from .models import (
ToolResponseBase, ToolResponseBase,
UserReadiness, UserReadiness,
) )
from .utils import ( from .utils import build_missing_credentials_from_field_info
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -81,6 +73,91 @@ class RunBlockTool(BaseTool):
def requires_auth(self) -> bool: def requires_auth(self) -> bool:
return True return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
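Both the old and new credential checks share one subtlety: when the discriminating input (for example a provider selector) is absent from input_data, the block schema's declared default is used instead. A distilled sketch of that fallback, using the names from the diff:

from pydantic_core import PydanticUndefined

value = input_data.get(field_info.discriminator)
if value is None:
    field = block.input_schema.model_fields.get(field_info.discriminator)
    if field and field.default is not PydanticUndefined:
        value = field.default  # fall back to the schema default
if value and value in field_info.discriminator_mapping:
    field_info = field_info.discriminate(value)  # narrow the provider set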
async def _execute( async def _execute(
self, self,
user_id: str | None, user_id: str | None,
@@ -135,24 +212,11 @@ class RunBlockTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
# Check if block is excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
return ErrorResponse(
message=(
f"Block '{block.name}' cannot be run directly in CoPilot. "
"This block is designed for use within graphs only."
),
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager() creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = ( matched_credentials, missing_credentials = await self._check_block_credentials(
await self._resolve_block_credentials(user_id, block, input_data) user_id, block, input_data
) )
if missing_credentials: if missing_credentials:
@@ -281,75 +345,29 @@ class RunBlockTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
async def _resolve_block_credentials( def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema.""" """Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema() schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys()) credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
def _resolve_discriminated_credentials( for field_name, field_schema in properties.items():
self, # Skip credential fields
block: AnyBlockSchema, if field_name in credentials_fields:
input_data: dict[str, Any], continue
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
resolved: dict[str, CredentialsFieldInfo] = {} inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
for field_name, field_info in credentials_fields_info.items(): return inputs_list
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved
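A self-contained sketch of the discriminator fallback above (input value first, then the schema field's default), with a simplified stand-in for CredentialsFieldInfo — the field names, providers, and mapping are illustrative, not the real model:

from dataclasses import dataclass, field

@dataclass
class FieldInfo:
    provider: set[str]
    discriminator: str | None = None
    discriminator_mapping: dict[str, str] = field(default_factory=dict)

    def discriminate(self, value: str) -> "FieldInfo":
        # Narrow to the single provider mapped to this discriminator value.
        return FieldInfo(provider={self.discriminator_mapping[value]})

def resolve(info: FieldInfo, input_data: dict, defaults: dict) -> FieldInfo:
    if not (info.discriminator and info.discriminator_mapping):
        return info
    value = input_data.get(info.discriminator) or defaults.get(info.discriminator)
    if value and value in info.discriminator_mapping:
        return info.discriminate(value)
    return info

llm_creds = FieldInfo(
    provider={"openai", "anthropic"},
    discriminator="model",
    discriminator_mapping={"gpt-4o": "openai", "claude-3": "anthropic"},
)
assert resolve(llm_creds, {"model": "claude-3"}, {}).provider == {"anthropic"}
assert resolve(llm_creds, {}, {"model": "gpt-4o"}).provider == {"openai"}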

View File

@@ -1,106 +0,0 @@
"""Tests for block execution guards in RunBlockTool."""
from unittest.mock import MagicMock, patch
import pytest
from backend.api.features.chat.tools.models import ErrorResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.data.block import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-run-block"
def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields_info.return_value = []
return mock
class TestRunBlockFiltering:
"""Tests for block execution guards in RunBlockTool."""
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_returns_error(self):
"""Attempting to execute a block with excluded BlockType returns error."""
session = make_session(user_id=_TEST_USER_ID)
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=input_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="input-block-id",
input_data={},
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
assert "designed for use within graphs only" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_returns_error(self):
"""Attempting to execute SmartDecisionMakerBlock returns error."""
session = make_session(user_id=_TEST_USER_ID)
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=smart_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=smart_decision_id,
input_data={},
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_non_excluded_block_passes_guard(self):
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
session = make_session(user_id=_TEST_USER_ID)
standard_block = make_mock_block(
"standard-id", "HTTP Request", BlockType.STANDARD
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=standard_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="standard-id",
input_data={},
)
# Should NOT be an ErrorResponse about CoPilot exclusion
# (may be other errors like missing credentials, but not the exclusion guard)
if isinstance(response, ErrorResponse):
assert "cannot be run directly in CoPilot" not in response.message

View File

@@ -6,14 +6,9 @@ from typing import Any
from backend.api.features.library import db as library_db from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel from backend.data.graph import GraphModel
from backend.data.model import ( from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
OAuth2Credentials,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import NotFoundError from backend.util.exceptions import NotFoundError
@@ -44,8 +39,14 @@ async def fetch_graph_from_store_slug(
return None, None return None, None
# Get the graph from store listing version # Get the graph from store listing version
graph = await store_db.get_available_graph( graph_meta = await store_db.get_available_graph(
store_agent.store_listing_version_id, hide_nodes=False store_agent.store_listing_version_id
)
graph = await graph_db.get_graph(
graph_id=graph_meta.id,
version=graph_meta.version,
user_id=None, # Public access
include_subgraphs=True,
) )
return graph, store_agent return graph, store_agent
@@ -122,7 +123,7 @@ def build_missing_credentials_from_graph(
return { return {
field_key: _serialize_missing_credential(field_key, field_info) field_key: _serialize_missing_credential(field_key, field_info)
for field_key, (field_info, _, _) in aggregated_fields.items() for field_key, (field_info, _node_fields) in aggregated_fields.items()
if field_key not in matched_keys if field_key not in matched_keys
} }
@@ -224,99 +225,6 @@ async def get_or_create_library_agent(
return library_agents[0] return library_agents[0]
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, scopes, and host."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if cred.type == "oauth2" and not _credential_has_required_scopes(
cred, field_info
):
continue
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
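The matching pass above in miniature — plain dicts stand in for Credentials and CredentialsFieldInfo (the real types live in backend.data.model), so this is a sketch of the shape, not the actual API:

def match(available: list[dict], requirements: dict[str, dict]):
    matched, missing = {}, []
    for field_name, req in requirements.items():
        cred = next(
            (c for c in available
             if c["provider"] in req["providers"] and c["type"] in req["types"]),
            None,
        )
        if cred:
            matched[field_name] = {"id": cred["id"], "provider": cred["provider"]}
        else:
            # Placeholder telling the caller which credential still needs setup.
            missing.append({
                "id": field_name,
                "provider": next(iter(req["providers"]), "unknown"),
                "title": field_name.replace("_", " ").title(),
            })
    return matched, missing

creds = [{"id": "c1", "provider": "github", "type": "api_key"}]
reqs = {
    "github_credentials": {"providers": {"github"}, "types": {"api_key"}},
    "notion_credentials": {"providers": {"notion"}, "types": {"oauth2"}},
}
matched, missing = match(creds, reqs)
assert "github_credentials" in matched
assert missing[0]["id"] == "notion_credentials"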
async def match_user_credentials_to_graph( async def match_user_credentials_to_graph(
user_id: str, user_id: str,
graph: GraphModel, graph: GraphModel,
@@ -356,8 +264,7 @@ async def match_user_credentials_to_graph(
# provider is in the set of acceptable providers. # provider is in the set of acceptable providers.
for credential_field_name, ( for credential_field_name, (
credential_requirements, credential_requirements,
_, _node_fields,
_,
) in aggregated_creds.items(): ) in aggregated_creds.items():
# Find first matching credential by provider, type, and scopes # Find first matching credential by provider, type, and scopes
matching_cred = next( matching_cred = next(
@@ -366,14 +273,7 @@ async def match_user_credentials_to_graph(
for cred in available_creds for cred in available_creds
if cred.provider in credential_requirements.provider if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types and cred.type in credential_requirements.supported_types
and ( and _credential_has_required_scopes(cred, credential_requirements)
cred.type != "oauth2"
or _credential_has_required_scopes(cred, credential_requirements)
)
and (
cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements)
)
), ),
None, None,
) )
@@ -418,32 +318,27 @@ async def match_user_credentials_to_graph(
def _credential_has_required_scopes( def _credential_has_required_scopes(
credential: OAuth2Credentials, credential: Credentials,
requirements: CredentialsFieldInfo, requirements: CredentialsFieldInfo,
) -> bool: ) -> bool:
"""Check if an OAuth2 credential has all the scopes required by the input.""" """
Check if a credential has all the scopes required by the block.
For OAuth2 credentials, verifies that the credential's scopes are a superset
of the required scopes. For other credential types, returns True (no scope check).
"""
# Only OAuth2 credentials have scopes to check
if credential.type != "oauth2":
return True
# If no scopes are required, any credential matches # If no scopes are required, any credential matches
if not requirements.required_scopes: if not requirements.required_scopes:
return True return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes) return set(credential.scopes).issuperset(requirements.required_scopes)
def _credential_is_for_host(
credential: HostScopedCredentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if a host-scoped credential matches the host required by the input."""
# We need to know the host to match host-scoped credentials to.
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
# to discriminator_values. No discriminator_values -> no host to match against.
if not requirements.discriminator_values:
return True
# Check that credential host matches required host.
# Host-scoped credential inputs are grouped by host, so any item from the set works.
return credential.matches_url(list(requirements.discriminator_values)[0])
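Both helpers reduce to simple predicates; the OAuth2 check is a set-superset test. In isolation (the scopes here are illustrative):

required = {"repo", "read:user"}
granted = {"repo", "read:user", "gist"}
assert granted.issuperset(required)        # credential qualifies
assert not {"gist"}.issuperset(required)   # missing scopes -> skipped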
async def check_user_has_required_credentials( async def check_user_has_required_credentials(
user_id: str, user_id: str,
required_credentials: list[CredentialsMetaInput], required_credentials: list[CredentialsMetaInput],

View File

@@ -19,10 +19,7 @@ from backend.data.graph import GraphSettings
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
from backend.data.model import CredentialsMetaInput from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import ( from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
on_graph_activate,
on_graph_deactivate,
)
from backend.util.clients import get_scheduler_client from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
from backend.util.json import SafeJson from backend.util.json import SafeJson
@@ -374,7 +371,7 @@ async def get_library_agent_by_graph_id(
async def add_generated_agent_image( async def add_generated_agent_image(
graph: graph_db.GraphBaseMeta, graph: graph_db.BaseGraph,
user_id: str, user_id: str,
library_agent_id: str, library_agent_id: str,
) -> Optional[prisma.models.LibraryAgent]: ) -> Optional[prisma.models.LibraryAgent]:
@@ -540,92 +537,6 @@ async def update_agent_version_in_library(
return library_model.LibraryAgent.from_db(lib) return library_model.LibraryAgent.from_db(lib)
async def create_graph_in_library(
graph: graph_db.Graph,
user_id: str,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new graph and add it to the user's library."""
graph.version = 1
graph_model = graph_db.make_graph_model(graph, user_id)
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
created_graph = await graph_db.create_graph(graph_model, user_id)
library_agents = await create_library_agent(
graph=created_graph,
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
)
if created_graph.is_active:
created_graph = await on_graph_activate(created_graph, user_id=user_id)
return created_graph, library_agents[0]
async def update_graph_in_library(
graph: graph_db.Graph,
user_id: str,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new version of an existing graph and update the library entry."""
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
current_active_version = (
next((v for v in existing_versions if v.is_active), None)
if existing_versions
else None
)
graph.version = (
max(v.version for v in existing_versions) + 1 if existing_versions else 1
)
graph_model = graph_db.make_graph_model(graph, user_id)
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
created_graph = await graph_db.create_graph(graph_model, user_id)
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
if not library_agent:
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
library_agent = await update_library_agent_version_and_settings(
user_id, created_graph
)
if created_graph.is_active:
created_graph = await on_graph_activate(created_graph, user_id=user_id)
await graph_db.set_graph_active_version(
graph_id=created_graph.id,
version=created_graph.version,
user_id=user_id,
)
if current_active_version:
await on_graph_deactivate(current_active_version, user_id=user_id)
return created_graph, library_agent
async def update_library_agent_version_and_settings(
user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
"""Update library agent to point to new graph version and sync settings."""
library = await update_agent_version_in_library(
user_id, agent_graph.id, agent_graph.version
)
updated_settings = GraphSettings.from_graph(
graph=agent_graph,
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
)
if updated_settings != library.settings:
library = await update_library_agent(
library_agent_id=library.id,
user_id=user_id,
settings=updated_settings,
)
return library
async def update_library_agent( async def update_library_agent(
library_agent_id: str, library_agent_id: str,
user_id: str, user_id: str,

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Literal, overload from typing import Any, Literal
import fastapi import fastapi
import prisma.enums import prisma.enums
@@ -11,8 +11,8 @@ import prisma.types
from backend.data.db import transaction from backend.data.db import transaction
from backend.data.graph import ( from backend.data.graph import (
GraphMeta,
GraphModel, GraphModel,
GraphModelWithoutNodes,
get_graph, get_graph,
get_graph_as_admin, get_graph_as_admin,
get_sub_graphs, get_sub_graphs,
@@ -334,22 +334,7 @@ async def get_store_agent_details(
raise DatabaseError("Failed to fetch agent details") from e raise DatabaseError("Failed to fetch agent details") from e
@overload async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
async def get_available_graph(
store_listing_version_id: str, hide_nodes: Literal[False]
) -> GraphModel: ...
@overload
async def get_available_graph(
store_listing_version_id: str, hide_nodes: Literal[True] = True
) -> GraphModelWithoutNodes: ...
async def get_available_graph(
store_listing_version_id: str,
hide_nodes: bool = True,
) -> GraphModelWithoutNodes | GraphModel:
try: try:
# Get available, non-deleted store listing version # Get available, non-deleted store listing version
store_listing_version = ( store_listing_version = (
@@ -359,7 +344,7 @@ async def get_available_graph(
"isAvailable": True, "isAvailable": True,
"isDeleted": False, "isDeleted": False,
}, },
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}}, include={"AgentGraph": {"include": {"Nodes": True}}},
) )
) )
@@ -369,9 +354,7 @@ async def get_available_graph(
detail=f"Store listing version {store_listing_version_id} not found", detail=f"Store listing version {store_listing_version_id} not found",
) )
return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db( return GraphModel.from_db(store_listing_version.AgentGraph).meta()
store_listing_version.AgentGraph
)
except Exception as e: except Exception as e:
logger.error(f"Error getting agent: {e}") logger.error(f"Error getting agent: {e}")

View File

@@ -8,7 +8,6 @@ Includes BM25 reranking for improved lexical relevance.
import logging import logging
import re import re
import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal from typing import Any, Literal
@@ -363,11 +362,7 @@ async def unified_hybrid_search(
LIMIT {limit_param} OFFSET {offset_param} LIMIT {limit_param} OFFSET {offset_param}
""" """
try: results = await query_raw_with_schema(sql_query, *params)
results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
total = results[0]["total_count"] if results else 0 total = results[0]["total_count"] if results else 0
# Apply BM25 reranking # Apply BM25 reranking
@@ -691,11 +686,7 @@ async def hybrid_search(
LIMIT {limit_param} OFFSET {offset_param} LIMIT {limit_param} OFFSET {offset_param}
""" """
try: results = await query_raw_with_schema(sql_query, *params)
results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
total = results[0]["total_count"] if results else 0 total = results[0]["total_count"] if results else 0
@@ -727,87 +718,6 @@ async def hybrid_search_simple(
return await hybrid_search(query=query, page=page, page_size=page_size) return await hybrid_search(query=query, page=page, page_size=page_size)
# ============================================================================
# Diagnostics
# ============================================================================
# Rate limit: only log vector error diagnostics once per this interval
_VECTOR_DIAG_INTERVAL_SECONDS = 60
_last_vector_diag_time: float = 0
async def _log_vector_error_diagnostics(error: Exception) -> None:
"""Log diagnostic info when 'type vector does not exist' error occurs.
Note: Diagnostic queries use query_raw_with_schema which may run on a different
pooled connection than the one that failed. Session-level search_path can differ,
so these diagnostics show cluster-wide state, not necessarily the failed session.
Includes rate limiting to avoid log spam - only logs once per minute.
Caller should re-raise the error after calling this function.
"""
global _last_vector_diag_time
# Check if this is the vector type error
error_str = str(error).lower()
if not (
"type" in error_str and "vector" in error_str and "does not exist" in error_str
):
return
# Rate limit: only log once per interval
now = time.time()
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
return
_last_vector_diag_time = now
try:
diagnostics: dict[str, object] = {}
try:
search_path_result = await query_raw_with_schema("SHOW search_path")
diagnostics["search_path"] = search_path_result
except Exception as e:
diagnostics["search_path"] = f"Error: {e}"
try:
schema_result = await query_raw_with_schema("SELECT current_schema()")
diagnostics["current_schema"] = schema_result
except Exception as e:
diagnostics["current_schema"] = f"Error: {e}"
try:
user_result = await query_raw_with_schema(
"SELECT current_user, session_user, current_database()"
)
diagnostics["user_info"] = user_result
except Exception as e:
diagnostics["user_info"] = f"Error: {e}"
try:
# Check pgvector extension installation (cluster-wide, stable info)
ext_result = await query_raw_with_schema(
"SELECT extname, extversion, nspname as schema "
"FROM pg_extension e "
"JOIN pg_namespace n ON e.extnamespace = n.oid "
"WHERE extname = 'vector'"
)
diagnostics["pgvector_extension"] = ext_result
except Exception as e:
diagnostics["pgvector_extension"] = f"Error: {e}"
logger.error(
f"Vector type error diagnostics:\n"
f" Error: {error}\n"
f" search_path: {diagnostics.get('search_path')}\n"
f" current_schema: {diagnostics.get('current_schema')}\n"
f" user_info: {diagnostics.get('user_info')}\n"
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
)
except Exception as diag_error:
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights # Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
# for existing code that expects the popularity parameter # for existing code that expects the popularity parameter
HybridSearchWeights = StoreAgentSearchWeights HybridSearchWeights = StoreAgentSearchWeights

View File

@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
StyleType, StyleType,
UpscaleOption, UpscaleOption,
) )
from backend.data.graph import GraphBaseMeta from backend.data.graph import BaseGraph
from backend.data.model import CredentialsMetaInput, ProviderName from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials from backend.integrations.credentials_store import ideogram_credentials
from backend.util.request import Requests from backend.util.request import Requests
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
DIGITAL_ART = "digital art" DIGITAL_ART = "digital art"
async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO: async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
if settings.config.use_agent_image_generation_v2: if settings.config.use_agent_image_generation_v2:
return await generate_agent_image_v2(graph=agent) return await generate_agent_image_v2(graph=agent)
else: else:
return await generate_agent_image_v1(agent=agent) return await generate_agent_image_v1(agent=agent)
async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO: async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
""" """
Generate an image for an agent using Ideogram model. Generate an image for an agent using Ideogram model.
Returns: Returns:
@@ -54,17 +54,14 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
description = f"{name} ({graph.description})" if graph.description else name description = f"{name} ({graph.description})" if graph.description else name
prompt = ( prompt = (
"Create a visually striking retro-futuristic vector pop art illustration " f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
f'prominently featuring "{name}" in bold typography. The image clearly and ' f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
f"literally depicts a {description}, along with recognizable objects directly " f"along with recognizable objects directly associated with the primary function of a {name}. "
f"associated with the primary function of a {name}. " f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
f"Ensure the imagery is concrete, intuitive, and immediately understandable, " f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
f"clearly conveying the purpose of a {name}. " f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
"Maintain vibrant, limited-palette colors, sharp vector lines, " f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
"geometric shapes, flat illustration techniques, and solid colors " f"prioritizing clear visual storytelling and thematic clarity above all else."
"without gradients or shading. Preserve a retro-futuristic aesthetic "
"influenced by mid-century futurism and 1960s psychedelia, "
"prioritizing clear visual storytelling and thematic clarity above all else."
) )
custom_colors = [ custom_colors = [
@@ -102,12 +99,12 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
return io.BytesIO(response.content) return io.BytesIO(response.content)
async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO: async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
""" """
Generate an image for an agent using Flux model via Replicate API. Generate an image for an agent using Flux model via Replicate API.
Args: Args:
agent (GraphBaseMeta | AgentGraph): The agent to generate an image for agent (Graph): The agent to generate an image for
Returns: Returns:
io.BytesIO: The generated image as bytes io.BytesIO: The generated image as bytes
@@ -117,13 +114,7 @@ async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.Bytes
raise ValueError("Missing Replicate API key in settings") raise ValueError("Missing Replicate API key in settings")
# Construct prompt from agent details # Construct prompt from agent details
prompt = ( prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
"Create a visually engaging app store thumbnail for the AI agent "
"that highlights what it does in a clear and captivating way:\n"
f"- **Name**: {agent.name}\n"
f"- **Description**: {agent.description}\n"
f"Focus on showcasing its core functionality with an appealing design."
)
# Set up Replicate client # Set up Replicate client
client = ReplicateClient(api_token=settings.secrets.replicate_api_key) client = ReplicateClient(api_token=settings.secrets.replicate_api_key)

View File

@@ -278,7 +278,7 @@ async def get_agent(
) )
async def get_graph_meta_by_store_listing_version_id( async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str, store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes: ) -> backend.data.graph.GraphMeta:
""" """
Get Agent Graph from Store Listing Version ID. Get Agent Graph from Store Listing Version ID.
""" """

View File

@@ -101,6 +101,7 @@ from backend.util.timezone_utils import (
from backend.util.virus_scanner import scan_content_safe from backend.util.virus_scanner import scan_content_safe
from .library import db as library_db from .library import db as library_db
from .library import model as library_model
from .store.model import StoreAgentDetails from .store.model import StoreAgentDetails
@@ -822,16 +823,18 @@ async def update_graph(
graph: graph_db.Graph, graph: graph_db.Graph,
user_id: Annotated[str, Security(get_user_id)], user_id: Annotated[str, Security(get_user_id)],
) -> graph_db.GraphModel: ) -> graph_db.GraphModel:
# Sanity check
if graph.id and graph.id != graph_id: if graph.id and graph.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI") raise HTTPException(400, detail="Graph ID does not match ID in URI")
# Determine new version
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id) existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not existing_versions: if not existing_versions:
raise HTTPException(404, detail=f"Graph #{graph_id} not found") raise HTTPException(404, detail=f"Graph #{graph_id} not found")
latest_version_number = max(g.version for g in existing_versions)
graph.version = latest_version_number + 1
graph.version = max(g.version for g in existing_versions) + 1
current_active_version = next((v for v in existing_versions if v.is_active), None) current_active_version = next((v for v in existing_versions if v.is_active), None)
graph = graph_db.make_graph_model(graph, user_id) graph = graph_db.make_graph_model(graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=False) graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
graph.validate_graph(for_run=False) graph.validate_graph(for_run=False)
@@ -839,23 +842,27 @@ async def update_graph(
new_graph_version = await graph_db.create_graph(graph, user_id=user_id) new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
if new_graph_version.is_active: if new_graph_version.is_active:
await library_db.update_library_agent_version_and_settings( # Keep the library agent up to date with the new active version
user_id, new_graph_version await _update_library_agent_version_and_settings(user_id, new_graph_version)
)
# Handle activation of the new graph first to ensure continuity
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id) new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
# Ensure new version is the only active version
await graph_db.set_graph_active_version( await graph_db.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version, user_id=user_id graph_id=graph_id, version=new_graph_version.version, user_id=user_id
) )
if current_active_version: if current_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(current_active_version, user_id=user_id) await on_graph_deactivate(current_active_version, user_id=user_id)
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
new_graph_version_with_subgraphs = await graph_db.get_graph( new_graph_version_with_subgraphs = await graph_db.get_graph(
graph_id, graph_id,
new_graph_version.version, new_graph_version.version,
user_id=user_id, user_id=user_id,
include_subgraphs=True, include_subgraphs=True,
) )
assert new_graph_version_with_subgraphs assert new_graph_version_with_subgraphs # make type checker happy
return new_graph_version_with_subgraphs return new_graph_version_with_subgraphs
@@ -893,15 +900,33 @@ async def set_graph_active_version(
) )
# Keep the library agent up to date with the new active version # Keep the library agent up to date with the new active version
await library_db.update_library_agent_version_and_settings( await _update_library_agent_version_and_settings(user_id, new_active_graph)
user_id, new_active_graph
)
if current_active_graph and current_active_graph.version != new_active_version: if current_active_graph and current_active_graph.version != new_active_version:
# Handle deactivation of the previously active version # Handle deactivation of the previously active version
await on_graph_deactivate(current_active_graph, user_id=user_id) await on_graph_deactivate(current_active_graph, user_id=user_id)
async def _update_library_agent_version_and_settings(
user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
library = await library_db.update_agent_version_in_library(
user_id, agent_graph.id, agent_graph.version
)
updated_settings = GraphSettings.from_graph(
graph=agent_graph,
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
)
if updated_settings != library.settings:
library = await library_db.update_library_agent(
library_agent_id=library.id,
user_id=user_id,
settings=updated_settings,
)
return library
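The version-bump rule used here — new version = highest existing version + 1, or 1 for a fresh graph — with illustrative stand-in objects:

from types import SimpleNamespace

existing = [SimpleNamespace(version=1, is_active=False),
            SimpleNamespace(version=3, is_active=True)]
new_version = max(v.version for v in existing) + 1 if existing else 1
assert new_version == 4
current_active = next((v for v in existing if v.is_active), None)
assert current_active and current_active.version == 3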
@v1_router.patch( @v1_router.patch(
path="/graphs/{graph_id}/settings", path="/graphs/{graph_id}/settings",
summary="Update graph settings", summary="Update graph settings",

View File

@@ -1,28 +0,0 @@
"""ElevenLabs integration blocks - test credentials and shared utilities."""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="elevenlabs",
api_key=SecretStr("mock-elevenlabs-api-key"),
title="Mock ElevenLabs API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
ElevenLabsCredentials = APIKeyCredentials
ElevenLabsCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.ELEVENLABS], Literal["api_key"]
]

View File

@@ -1,77 +0,0 @@
"""Text encoding block for converting special characters to escape sequences."""
import codecs
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
class TextEncoderBlock(Block):
"""
Encodes a string by converting special characters into escape sequences.
This block is the inverse of TextDecoderBlock. It takes text containing
special characters (like newlines, tabs, etc.) and converts them into
their escape sequence representations (e.g., newline becomes \\n).
"""
class Input(BlockSchemaInput):
"""Input schema for TextEncoderBlock."""
text: str = SchemaField(
description="A string containing special characters to be encoded",
placeholder="Your text with newlines and quotes to encode",
)
class Output(BlockSchemaOutput):
"""Output schema for TextEncoderBlock."""
encoded_text: str = SchemaField(
description="The encoded text with special characters converted to escape sequences"
)
error: str = SchemaField(description="Error message if encoding fails")
def __init__(self):
super().__init__(
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
description="Encodes a string by converting special characters into escape sequences",
categories={BlockCategory.TEXT},
input_schema=TextEncoderBlock.Input,
output_schema=TextEncoderBlock.Output,
test_input={
"text": """Hello
World!
This is a "quoted" string."""
},
test_output=[
(
"encoded_text",
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
)
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""
Encode the input text by converting special characters to escape sequences.
Args:
input_data: The input containing the text to encode.
**kwargs: Additional keyword arguments (unused).
Yields:
The encoded text with escape sequences, or an error message if encoding fails.
"""
try:
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
"utf-8"
)
yield "encoded_text", encoded_text
except Exception as e:
yield "error", f"Encoding error: {str(e)}"

View File

@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
try: try:
webset = await aexa.websets.get(id=input_data.external_id) webset = aexa.websets.get(id=input_data.external_id)
webset_result = Webset.model_validate(webset.model_dump(by_alias=True)) webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
yield "webset", webset_result yield "webset", webset_result
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
count=input_data.search_count, count=input_data.search_count,
) )
webset = await aexa.websets.create( webset = aexa.websets.create(
params=CreateWebsetParameters( params=CreateWebsetParameters(
search=search_params, search=search_params,
external_id=input_data.external_id, external_id=input_data.external_id,
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
if input_data.metadata is not None: if input_data.metadata is not None:
payload["metadata"] = input_data.metadata payload["metadata"] = input_data.metadata
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload) sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
status_str = ( status_str = (
sdk_webset.status.value sdk_webset.status.value
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
) -> BlockOutput: ) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.websets.list( response = aexa.websets.list(
cursor=input_data.cursor, cursor=input_data.cursor,
limit=input_data.limit, limit=input_data.limit,
) )
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
) -> BlockOutput: ) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_webset = await aexa.websets.get(id=input_data.webset_id) sdk_webset = aexa.websets.get(id=input_data.webset_id)
status_str = ( status_str = (
sdk_webset.status.value sdk_webset.status.value
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
) -> BlockOutput: ) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_webset = await aexa.websets.delete(id=input_data.webset_id) deleted_webset = aexa.websets.delete(id=input_data.webset_id)
status_str = ( status_str = (
deleted_webset.status.value deleted_webset.status.value
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
) -> BlockOutput: ) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id) canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
status_str = ( status_str = (
canceled_webset.status.value canceled_webset.status.value
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
entity["description"] = input_data.entity_description entity["description"] = input_data.entity_description
payload["entity"] = entity payload["entity"] = entity
sdk_preview = await aexa.websets.preview(params=payload) sdk_preview = aexa.websets.preview(params=payload)
preview = PreviewWebsetModel.from_sdk(sdk_preview) preview = PreviewWebsetModel.from_sdk(sdk_preview)
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
) -> BlockOutput: ) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
webset = await aexa.websets.get(id=input_data.webset_id) webset = aexa.websets.get(id=input_data.webset_id)
status = ( status = (
webset.status.value webset.status.value
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
) -> BlockOutput: ) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
webset = await aexa.websets.get(id=input_data.webset_id) webset = aexa.websets.get(id=input_data.webset_id)
# Extract basic info # Extract basic info
webset_id = webset.id webset_id = webset.id
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
total_items = 0 total_items = 0
if input_data.include_sample_items and input_data.sample_size > 0: if input_data.include_sample_items and input_data.sample_size > 0:
items_response = await aexa.websets.items.list( items_response = aexa.websets.items.list(
webset_id=input_data.webset_id, limit=input_data.sample_size webset_id=input_data.webset_id, limit=input_data.sample_size
) )
sample_items_data = [ sample_items_data = [
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Get webset details # Get webset details
webset = await aexa.websets.get(id=input_data.webset_id) webset = aexa.websets.get(id=input_data.webset_id)
status = ( status = (
webset.status.value webset.status.value

View File

@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_enrichment = await aexa.websets.enrichments.create( sdk_enrichment = aexa.websets.enrichments.create(
webset_id=input_data.webset_id, params=payload webset_id=input_data.webset_id, params=payload
) )
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
items_enriched = 0 items_enriched = 0
while time.time() - poll_start < input_data.polling_timeout: while time.time() - poll_start < input_data.polling_timeout:
current_enrich = await aexa.websets.enrichments.get( current_enrich = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=enrichment_id webset_id=input_data.webset_id, id=enrichment_id
) )
current_status = ( current_status = (
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
if current_status in ["completed", "failed", "cancelled"]: if current_status in ["completed", "failed", "cancelled"]:
# Estimate items from webset searches # Estimate items from webset searches
webset = await aexa.websets.get(id=input_data.webset_id) webset = aexa.websets.get(id=input_data.webset_id)
if webset.searches: if webset.searches:
for search in webset.searches: for search in webset.searches:
if search.progress: if search.progress:
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_enrichment = await aexa.websets.enrichments.get( sdk_enrichment = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id webset_id=input_data.webset_id, id=input_data.enrichment_id
) )
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_enrichment = await aexa.websets.enrichments.delete( deleted_enrichment = aexa.websets.enrichments.delete(
webset_id=input_data.webset_id, id=input_data.enrichment_id webset_id=input_data.webset_id, id=input_data.enrichment_id
) )
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
canceled_enrichment = await aexa.websets.enrichments.cancel( canceled_enrichment = aexa.websets.enrichments.cancel(
webset_id=input_data.webset_id, id=input_data.enrichment_id webset_id=input_data.webset_id, id=input_data.enrichment_id
) )
# Try to estimate how many items were enriched before cancellation # Try to estimate how many items were enriched before cancellation
items_enriched = 0 items_enriched = 0
items_response = await aexa.websets.items.list( items_response = aexa.websets.items.list(
webset_id=input_data.webset_id, limit=100 webset_id=input_data.webset_id, limit=100
) )

View File

@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
def _create_test_mock(): def _create_test_mock():
"""Create test mocks for the AsyncExa SDK.""" """Create test mocks for the AsyncExa SDK."""
from datetime import datetime from datetime import datetime
from unittest.mock import AsyncMock, MagicMock from unittest.mock import MagicMock
# Create mock SDK import object # Create mock SDK import object
mock_import = MagicMock() mock_import = MagicMock()
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
return { return {
"_get_client": lambda *args, **kwargs: MagicMock( "_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock( websets=MagicMock(
imports=MagicMock(create=AsyncMock(return_value=mock_import)) imports=MagicMock(create=lambda *args, **kwargs: mock_import)
) )
) )
} }
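The MagicMock/AsyncMock distinction in this hunk matters because a plain MagicMock return value is not awaitable, while AsyncMock(return_value=...) produces a coroutine-returning callable. A minimal repro:

import asyncio
from unittest.mock import AsyncMock, MagicMock

client = MagicMock()
client.websets.imports.create = AsyncMock(return_value={"id": "import-1"})

async def main():
    result = await client.websets.imports.create(params={})
    assert result["id"] == "import-1"

asyncio.run(main())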
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
if input_data.metadata: if input_data.metadata:
payload["metadata"] = input_data.metadata payload["metadata"] = input_data.metadata
sdk_import = await aexa.websets.imports.create( sdk_import = aexa.websets.imports.create(
params=payload, csv_data=input_data.csv_data params=payload, csv_data=input_data.csv_data
) )
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id) sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
import_obj = ImportModel.from_sdk(sdk_import) import_obj = ImportModel.from_sdk(sdk_import)
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.websets.imports.list( response = aexa.websets.imports.list(
cursor=input_data.cursor, cursor=input_data.cursor,
limit=input_data.limit, limit=input_data.limit,
) )
@@ -474,9 +474,7 @@ class ExaDeleteImportBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_import = await aexa.websets.imports.delete( deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
import_id=input_data.import_id
)
yield "import_id", deleted_import.id yield "import_id", deleted_import.id
yield "success", "true" yield "success", "true"
@@ -575,14 +573,14 @@ class ExaExportWebsetBlock(Block):
} }
) )
# Create async iterator for list_all # Create mock iterator
async def async_item_iterator(*args, **kwargs): mock_items = [mock_item1, mock_item2]
for item in [mock_item1, mock_item2]:
yield item
return { return {
"_get_client": lambda *args, **kwargs: MagicMock( "_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock(items=MagicMock(list_all=async_item_iterator)) websets=MagicMock(
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
)
) )
} }
@@ -604,7 +602,7 @@ class ExaExportWebsetBlock(Block):
webset_id=input_data.webset_id, limit=input_data.max_items webset_id=input_data.webset_id, limit=input_data.max_items
) )
async for sdk_item in item_iterator: for sdk_item in item_iterator:
if len(all_items) >= input_data.max_items: if len(all_items) >= input_data.max_items:
break break
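Likewise, "async for" requires an async iterator, which a plain iter() cannot provide; an async generator function is the simplest mock for list_all. Self-contained:

import asyncio

async def fake_list_all(*args, **kwargs):
    # Async generator: what `async for sdk_item in item_iterator` expects.
    for item in ("item-1", "item-2"):
        yield item

async def main():
    collected = [item async for item in fake_list_all()]
    assert collected == ["item-1", "item-2"]

asyncio.run(main())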

View File

@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
) -> BlockOutput: ) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_item = await aexa.websets.items.get( sdk_item = aexa.websets.items.get(
webset_id=input_data.webset_id, id=input_data.item_id webset_id=input_data.webset_id, id=input_data.item_id
) )
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
response = None response = None
while time.time() - start_time < input_data.wait_timeout: while time.time() - start_time < input_data.wait_timeout:
response = await aexa.websets.items.list( response = aexa.websets.items.list(
webset_id=input_data.webset_id, webset_id=input_data.webset_id,
cursor=input_data.cursor, cursor=input_data.cursor,
limit=input_data.limit, limit=input_data.limit,
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
interval = min(interval * 1.2, 10) interval = min(interval * 1.2, 10)
if not response: if not response:
response = await aexa.websets.items.list( response = aexa.websets.items.list(
webset_id=input_data.webset_id, webset_id=input_data.webset_id,
cursor=input_data.cursor, cursor=input_data.cursor,
limit=input_data.limit, limit=input_data.limit,
) )
else: else:
response = await aexa.websets.items.list( response = aexa.websets.items.list(
webset_id=input_data.webset_id, webset_id=input_data.webset_id,
cursor=input_data.cursor, cursor=input_data.cursor,
limit=input_data.limit, limit=input_data.limit,
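The wait loops in these blocks share one shape: poll, sleep, grow the interval by 1.2x, cap it at 10 seconds. A sketch of that pattern (poll_until is an illustrative helper, not part of the SDK):

import time

def poll_until(check, timeout: float, interval: float = 1.0, cap: float = 10.0):
    """Call check() until it returns a truthy value or the timeout expires."""
    start = time.time()
    while time.time() - start < timeout:
        result = check()
        if result:
            return result
        time.sleep(interval)
        interval = min(interval * 1.2, cap)  # capped exponential backoff
    return None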
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
) -> BlockOutput: ) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_item = await aexa.websets.items.delete( deleted_item = aexa.websets.items.delete(
webset_id=input_data.webset_id, id=input_data.item_id webset_id=input_data.webset_id, id=input_data.item_id
) )
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
webset_id=input_data.webset_id, limit=input_data.max_items webset_id=input_data.webset_id, limit=input_data.max_items
) )
async for sdk_item in item_iterator: for sdk_item in item_iterator:
if len(all_items) >= input_data.max_items: if len(all_items) >= input_data.max_items:
break break
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
webset = await aexa.websets.get(id=input_data.webset_id) webset = aexa.websets.get(id=input_data.webset_id)
entity_type = "unknown" entity_type = "unknown"
if webset.searches: if webset.searches:
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
# Get sample items if requested # Get sample items if requested
sample_items: List[WebsetItemModel] = [] sample_items: List[WebsetItemModel] = []
if input_data.sample_size > 0: if input_data.sample_size > 0:
items_response = await aexa.websets.items.list( items_response = aexa.websets.items.list(
webset_id=input_data.webset_id, limit=input_data.sample_size webset_id=input_data.webset_id, limit=input_data.sample_size
) )
# Convert to our stable models # Convert to our stable models
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Get items starting from cursor # Get items starting from cursor
response = await aexa.websets.items.list( response = aexa.websets.items.list(
webset_id=input_data.webset_id, webset_id=input_data.webset_id,
cursor=input_data.since_cursor, cursor=input_data.since_cursor,
limit=input_data.max_items, limit=input_data.max_items,

View File

@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
def _create_test_mock(): def _create_test_mock():
"""Create test mocks for the AsyncExa SDK.""" """Create test mocks for the AsyncExa SDK."""
from datetime import datetime from datetime import datetime
from unittest.mock import AsyncMock, MagicMock from unittest.mock import MagicMock
# Create mock SDK monitor object # Create mock SDK monitor object
mock_monitor = MagicMock() mock_monitor = MagicMock()
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
return { return {
"_get_client": lambda *args, **kwargs: MagicMock( "_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock( websets=MagicMock(
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor)) monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
) )
) )
} }
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
if input_data.metadata: if input_data.metadata:
payload["metadata"] = input_data.metadata payload["metadata"] = input_data.metadata
sdk_monitor = await aexa.websets.monitors.create(params=payload) sdk_monitor = aexa.websets.monitors.create(params=payload)
monitor = MonitorModel.from_sdk(sdk_monitor) monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id) sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
monitor = MonitorModel.from_sdk(sdk_monitor) monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
if input_data.metadata is not None: if input_data.metadata is not None:
payload["metadata"] = input_data.metadata payload["metadata"] = input_data.metadata
sdk_monitor = await aexa.websets.monitors.update( sdk_monitor = aexa.websets.monitors.update(
monitor_id=input_data.monitor_id, params=payload monitor_id=input_data.monitor_id, params=payload
) )
@@ -522,9 +522,7 @@ class ExaDeleteMonitorBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_monitor = await aexa.websets.monitors.delete( deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
monitor_id=input_data.monitor_id
)
yield "monitor_id", deleted_monitor.id yield "monitor_id", deleted_monitor.id
yield "success", "true" yield "success", "true"
@@ -581,7 +579,7 @@ class ExaListMonitorsBlock(Block):
# Use AsyncExa SDK # Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.websets.monitors.list( response = aexa.websets.monitors.list(
cursor=input_data.cursor, cursor=input_data.cursor,
limit=input_data.limit, limit=input_data.limit,
webset_id=input_data.webset_id, webset_id=input_data.webset_id,

View File

@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
WebsetTargetStatus.IDLE,
WebsetTargetStatus.ANY_COMPLETE,
]:
- final_webset = await aexa.websets.wait_until_idle(
+ final_webset = aexa.websets.wait_until_idle(
id=input_data.webset_id,
timeout=input_data.timeout,
poll_interval=input_data.check_interval,
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
interval = input_data.check_interval
while time.time() - start_time < input_data.timeout:
# Get current webset status
- webset = await aexa.websets.get(id=input_data.webset_id)
+ webset = aexa.websets.get(id=input_data.webset_id)
current_status = (
webset.status.value
if hasattr(webset.status, "value")
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
# Timeout reached
elapsed = time.time() - start_time
- webset = await aexa.websets.get(id=input_data.webset_id)
+ webset = aexa.websets.get(id=input_data.webset_id)
final_status = (
webset.status.value
if hasattr(webset.status, "value")
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
try:
while time.time() - start_time < input_data.timeout:
# Get current search status using SDK
- search = await aexa.websets.searches.get(
+ search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
elapsed = time.time() - start_time
# Get last known status
- search = await aexa.websets.searches.get(
+ search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)
final_status = (
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
try:
while time.time() - start_time < input_data.timeout:
# Get current enrichment status using SDK
- enrichment = await aexa.websets.enrichments.get(
+ enrichment = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
elapsed = time.time() - start_time
# Get last known status
- enrichment = await aexa.websets.enrichments.get(
+ enrichment = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
final_status = (
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
) -> tuple[list[SampleEnrichmentModel], int]:
"""Get sample enriched data and count."""
# Get a few items to see enrichment results using SDK
- response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
+ response = aexa.websets.items.list(webset_id=webset_id, limit=5)
sample_data: list[SampleEnrichmentModel] = []
enriched_count = 0
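All three wait blocks in this file share the same poll-with-timeout shape. A hedged sketch of that pattern (get_status and the terminal-state set are stand-ins, not the SDK API):

import time

def wait_for_status(get_status, timeout: float = 60.0, interval: float = 2.0) -> str:
    """Poll get_status() every `interval` seconds until a terminal state or timeout."""
    start = time.time()
    while time.time() - start < timeout:
        status = get_status()
        if status in {"idle", "completed", "canceled"}:  # illustrative terminal set
            return status
        time.sleep(interval)
    # Timeout reached: report the last known status, as the blocks above do.
    return get_status()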
sample_data: list[SampleEnrichmentModel] = [] sample_data: list[SampleEnrichmentModel] = []
enriched_count = 0 enriched_count = 0

View File

@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
- sdk_search = await aexa.websets.searches.create(
+ sdk_search = aexa.websets.searches.create(
webset_id=input_data.webset_id, params=payload
)
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
poll_start = time.time()
while time.time() - poll_start < input_data.polling_timeout:
- current_search = await aexa.websets.searches.get(
+ current_search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=search_id
)
current_status = (
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
- sdk_search = await aexa.websets.searches.get(
+ sdk_search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
- canceled_search = await aexa.websets.searches.cancel(
+ canceled_search = aexa.websets.searches.cancel(
webset_id=input_data.webset_id, id=input_data.search_id
)
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Get webset to check existing searches
- webset = await aexa.websets.get(id=input_data.webset_id)
+ webset = aexa.websets.get(id=input_data.webset_id)
# Look for existing search with same query
existing_search = None
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
if input_data.entity_type != SearchEntityType.AUTO:
payload["entity"] = {"type": input_data.entity_type.value}
- sdk_search = await aexa.websets.searches.create(
+ sdk_search = aexa.websets.searches.create(
webset_id=input_data.webset_id, params=payload
)
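The ExaFindOrCreateSearchBlock hunks above implement a find-or-create flow. A minimal sketch of that logic (attribute names assumed from the surrounding diff, not verified against the SDK):

def find_or_create_search(webset, create_search, query: str):
    # Reuse an existing search on the webset if one matches the query...
    for search in getattr(webset, "searches", None) or []:
        if getattr(search, "query", None) == query:
            return search, False
    # ...otherwise create a new one.
    return create_search(query), True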

View File

@@ -162,16 +162,8 @@ class LinearClient:
"searchTerm": team_name,
}
- result = await self.query(query, variables)
- nodes = result["teams"]["nodes"]
- if not nodes:
- raise LinearAPIException(
- f"Team '{team_name}' not found. Check the team name or key and try again.",
- status_code=404,
- )
- return nodes[0]["id"]
+ team_id = await self.query(query, variables)
+ return team_id["teams"]["nodes"][0]["id"]
except LinearAPIException as e:
raise e
@@ -248,44 +240,17 @@ class LinearClient:
except LinearAPIException as e:
raise e
- async def try_search_issues(
- self,
- term: str,
- max_results: int = 10,
- team_id: str | None = None,
- ) -> list[Issue]:
+ async def try_search_issues(self, term: str) -> list[Issue]:
try:
query = """
- query SearchIssues(
- $term: String!,
- $first: Int,
- $teamId: String
- ) {
- searchIssues(
- term: $term,
- first: $first,
- teamId: $teamId
- ) {
+ query SearchIssues($term: String!, $includeComments: Boolean!) {
+ searchIssues(term: $term, includeComments: $includeComments) {
nodes {
id
identifier
title
description
priority
- createdAt
- state {
- id
- name
- type
- }
- project {
- id
- name
- }
- assignee {
- id
- name
- }
}
}
}
@@ -293,8 +258,7 @@ class LinearClient:
variables: dict[str, Any] = {
"term": term,
- "first": max_results,
- "teamId": team_id,
+ "includeComments": True,
}
issues = await self.query(query, variables)
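For context, the removed query above is driven entirely by the variables dict; a hedged sketch of the pairing (the result shape is inferred from the teams query earlier in this file, not exercised here):

query = """
query SearchIssues($term: String!, $first: Int, $teamId: String) {
  searchIssues(term: $term, first: $first, teamId: $teamId) {
    nodes { id identifier title }
  }
}
"""
variables = {"term": "payment bug", "first": 10, "teamId": None}
# result = await self.query(query, variables)   # as in LinearClient above
# issues = result["searchIssues"]["nodes"]      # assumed response shape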

View File

@@ -17,7 +17,7 @@ from ._config import (
LinearScope,
linear,
)
- from .models import CreateIssueResponse, Issue, State
+ from .models import CreateIssueResponse, Issue
class LinearCreateIssueBlock(Block):
@@ -135,20 +135,9 @@ class LinearSearchIssuesBlock(Block):
description="Linear credentials with read permissions",
required_scopes={LinearScope.READ},
)
- max_results: int = SchemaField(
- description="Maximum number of results to return",
- default=10,
- ge=1,
- le=100,
- )
- team_name: str | None = SchemaField(
- description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
- default=None,
- )
class Output(BlockSchemaOutput):
issues: list[Issue] = SchemaField(description="List of issues")
+ error: str = SchemaField(description="Error message if the search failed")
def __init__(self):
super().__init__(
@@ -156,11 +145,8 @@ class LinearSearchIssuesBlock(Block):
description="Searches for issues on Linear",
input_schema=self.Input,
output_schema=self.Output,
- categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
test_input={
"term": "Test issue",
- "max_results": 10,
- "team_name": None,
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
test_credentials=TEST_CREDENTIALS_OAUTH,
@@ -170,14 +156,10 @@ class LinearSearchIssuesBlock(Block):
[
Issue(
id="abc123",
- identifier="TST-123",
+ identifier="abc123",
title="Test issue",
description="Test description",
priority=1,
- state=State(
- id="state1", name="In Progress", type="started"
- ),
- createdAt="2026-01-15T10:00:00.000Z",
)
],
)
@@ -186,12 +168,10 @@ class LinearSearchIssuesBlock(Block):
"search_issues": lambda *args, **kwargs: [
Issue(
id="abc123",
- identifier="TST-123",
+ identifier="abc123",
title="Test issue",
description="Test description",
priority=1,
- state=State(id="state1", name="In Progress", type="started"),
- createdAt="2026-01-15T10:00:00.000Z",
)
]
},
@@ -201,22 +181,10 @@ class LinearSearchIssuesBlock(Block):
async def search_issues(
credentials: OAuth2Credentials | APIKeyCredentials,
term: str,
- max_results: int = 10,
- team_name: str | None = None,
) -> list[Issue]:
client = LinearClient(credentials=credentials)
- # Resolve team name to ID if provided
- # Raises LinearAPIException with descriptive message if team not found
- team_id: str | None = None
- if team_name:
- team_id = await client.try_get_team_by_name(team_name=team_name)
- return await client.try_search_issues(
- term=term,
- max_results=max_results,
- team_id=team_id,
- )
+ response: list[Issue] = await client.try_search_issues(term=term)
+ return response
async def run(
self,
@@ -228,10 +196,7 @@ class LinearSearchIssuesBlock(Block):
"""Execute the issue search"""
try:
issues = await self.search_issues(
- credentials=credentials,
- term=input_data.term,
- max_results=input_data.max_results,
- team_name=input_data.team_name,
+ credentials=credentials, term=input_data.term
)
yield "issues", issues
except LinearAPIException as e:

View File

@@ -36,21 +36,12 @@ class Project(BaseModel):
content: str | None = None
- class State(BaseModel):
- id: str
- name: str
- type: str | None = (
- None  # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
- )
class Issue(BaseModel):
id: str
identifier: str
title: str
description: str | None
priority: int
- state: State | None = None
project: Project | None = None
createdAt: str | None = None
comments: list[Comment] | None = None
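A minimal sketch of the removed nested-model shape (field values illustrative), showing how State hangs off Issue as an optional field:

from pydantic import BaseModel

class State(BaseModel):
    id: str
    name: str
    type: str | None = None  # e.g. "triage", "backlog", "started"

class Issue(BaseModel):
    id: str
    identifier: str
    title: str
    description: str | None
    priority: int
    state: State | None = None

issue = Issue(
    id="abc123",
    identifier="TST-123",
    title="Test issue",
    description=None,
    priority=1,
    state=State(id="state1", name="In Progress", type="started"),
)
assert issue.state is not None and issue.state.name == "In Progress"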

View File

@@ -115,7 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
- CLAUDE_4_6_OPUS = "claude-opus-4-6"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -271,9 +270,6 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
),  # claude-4-sonnet-20250514
- LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
- "anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
- ),  # claude-opus-4-6
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
),  # claude-opus-4-5-20251101
@@ -531,12 +527,12 @@ class LLMResponse(BaseModel):
def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
- ) -> Iterable[ToolParam] | anthropic.Omit:
+ ) -> Iterable[ToolParam] | anthropic.NotGiven:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
if not openai_tools or len(openai_tools) == 0:
- return anthropic.omit
+ return anthropic.NOT_GIVEN
anthropic_tools = []
for tool in openai_tools:
@@ -596,10 +592,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
def get_parallel_tool_calls_param(
llm_model: LlmModel, parallel_tool_calls: bool | None
- ) -> bool | openai.Omit:
+ ):
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
if llm_model.startswith("o") or parallel_tool_calls is None:
- return openai.omit
+ return openai.NOT_GIVEN
return parallel_tool_calls
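Both hunks above swap one "omitted parameter" sentinel for another. A generic sketch of why such a sentinel exists at all (this shows the pattern, not either SDK's implementation): it lets callers distinguish "not provided" from None or False.

class _NotGiven:
    def __repr__(self) -> str:
        return "NOT_GIVEN"

NOT_GIVEN = _NotGiven()

def parallel_tool_calls_param(model: str, value: bool | None):
    # Mirrors the logic above: reasoning-style models ("o...") and unset
    # values omit the parameter entirely rather than sending None or False.
    if model.startswith("o") or value is None:
        return NOT_GIVEN
    return value

assert parallel_tool_calls_param("o3-mini", True) is NOT_GIVEN
assert parallel_tool_calls_param("gpt-4o", True) is True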

View File

@@ -0,0 +1,246 @@
import os
import tempfile
from typing import Optional
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class MediaDurationBlock(Block):
class Input(BlockSchemaInput):
media_in: MediaFileType = SchemaField(
description="Media input (URL, data URI, or local path)."
)
is_video: bool = SchemaField(
description="Whether the media is a video (True) or audio (False).",
default=True,
)
class Output(BlockSchemaOutput):
duration: float = SchemaField(
description="Duration of the media file (in seconds)."
)
def __init__(self):
super().__init__(
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
description="Block to get the duration of a media file.",
categories={BlockCategory.MULTIMEDIA},
input_schema=MediaDurationBlock.Input,
output_schema=MediaDurationBlock.Output,
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
file=input_data.media_in,
execution_context=execution_context,
return_format="for_local_processing",
)
assert execution_context.graph_exec_id is not None
media_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_media_path
)
# 2) Load the clip
if input_data.is_video:
clip = VideoFileClip(media_abspath)
else:
clip = AudioFileClip(media_abspath)
yield "duration", clip.duration
class LoopVideoBlock(Block):
"""
Block for looping (repeating) a video clip until a given duration or number of loops.
"""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="The input video (can be a URL, data URI, or local path)."
)
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
duration: Optional[float] = SchemaField(
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
default=None,
ge=0.0,
)
n_loops: Optional[int] = SchemaField(
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
default=None,
ge=1,
)
class Output(BlockSchemaOutput):
video_out: str = SchemaField(
description="Looped video returned either as a relative path or a data URI."
)
def __init__(self):
super().__init__(
id="8bf9eef6-5451-4213-b265-25306446e94b",
description="Block to loop a video to a given duration or number of repeats.",
categories={BlockCategory.MULTIMEDIA},
input_schema=LoopVideoBlock.Input,
output_schema=LoopVideoBlock.Output,
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the input video locally
local_video_path = await store_media_file(
file=input_data.video_in,
execution_context=execution_context,
return_format="for_local_processing",
)
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
# 2) Load the clip
clip = VideoFileClip(input_abspath)
# 3) Apply the loop effect
looped_clip = clip
if input_data.duration:
# Loop until we reach the specified duration
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
elif input_data.n_loops:
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
else:
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
assert isinstance(looped_clip, VideoFileClip)
# 4) Save the looped output
output_filename = MediaFileType(
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
)
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
looped_clip = looped_clip.with_audio(clip.audio)
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
file=output_filename,
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out
class AddAudioToVideoBlock(Block):
"""
Block that adds (attaches) an audio track to an existing video.
Optionally scale the volume of the new track.
"""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="Video input (URL, data URI, or local path)."
)
audio_in: MediaFileType = SchemaField(
description="Audio input (URL, data URI, or local path)."
)
volume: float = SchemaField(
description="Volume scale for the newly attached audio track (1.0 = original).",
default=1.0,
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Final video (with attached audio), as a path or data URI."
)
def __init__(self):
super().__init__(
id="3503748d-62b6-4425-91d6-725b064af509",
description="Block to attach an audio file to a video file using moviepy.",
categories={BlockCategory.MULTIMEDIA},
input_schema=AddAudioToVideoBlock.Input,
output_schema=AddAudioToVideoBlock.Output,
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the inputs locally
local_video_path = await store_media_file(
file=input_data.video_in,
execution_context=execution_context,
return_format="for_local_processing",
)
local_audio_path = await store_media_file(
file=input_data.audio_in,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
video_abspath = os.path.join(abs_temp_dir, local_video_path)
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
# 2) Load video + audio with moviepy
video_clip = VideoFileClip(video_abspath)
audio_clip = AudioFileClip(audio_abspath)
# Optionally scale volume
if input_data.volume != 1.0:
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
# 3) Attach the new audio track
final_clip = video_clip.with_audio(audio_clip)
# 4) Write to output file
output_filename = MediaFileType(
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
)
output_abspath = os.path.join(abs_temp_dir, output_filename)
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# 5) Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
file=output_filename,
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out

View File

@@ -1,77 +0,0 @@
import pytest
from backend.blocks.encoder_block import TextEncoderBlock
@pytest.mark.asyncio
async def test_text_encoder_basic():
"""Test basic encoding of newlines and special characters."""
block = TextEncoderBlock()
result = []
async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
assert result[0][1] == "Hello\\nWorld"
@pytest.mark.asyncio
async def test_text_encoder_multiple_escapes():
"""Test encoding of multiple escape sequences."""
block = TextEncoderBlock()
result = []
async for output in block.run(
TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
assert "\\n" in result[0][1]
assert "\\t" in result[0][1]
assert "\\r" in result[0][1]
@pytest.mark.asyncio
async def test_text_encoder_unicode():
"""Test that unicode characters are handled correctly."""
block = TextEncoderBlock()
result = []
async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
# Unicode characters should be escaped as \uXXXX sequences
assert "\\n" in result[0][1]
@pytest.mark.asyncio
async def test_text_encoder_empty_string():
"""Test encoding of an empty string."""
block = TextEncoderBlock()
result = []
async for output in block.run(TextEncoderBlock.Input(text="")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
assert result[0][1] == ""
@pytest.mark.asyncio
async def test_text_encoder_error_handling():
"""Test that encoding errors are handled gracefully."""
from unittest.mock import patch
block = TextEncoderBlock()
result = []
with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
async for output in block.run(TextEncoderBlock.Input(text="test")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "error"
assert "Mocked encoding error" in result[0][1]

View File

@@ -1,37 +0,0 @@
"""Video editing blocks for AutoGPT Platform.
This module provides blocks for:
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
- Clipping/trimming video segments
- Concatenating multiple videos
- Adding text overlays
- Adding AI-generated narration
- Getting media duration
- Looping videos
- Adding audio to videos
Dependencies:
- yt-dlp: For video downloading
- moviepy: For video editing operations
- elevenlabs: For AI narration (optional)
"""
from backend.blocks.video.add_audio import AddAudioToVideoBlock
from backend.blocks.video.clip import VideoClipBlock
from backend.blocks.video.concat import VideoConcatBlock
from backend.blocks.video.download import VideoDownloadBlock
from backend.blocks.video.duration import MediaDurationBlock
from backend.blocks.video.loop import LoopVideoBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.blocks.video.text_overlay import VideoTextOverlayBlock
__all__ = [
"AddAudioToVideoBlock",
"LoopVideoBlock",
"MediaDurationBlock",
"VideoClipBlock",
"VideoConcatBlock",
"VideoDownloadBlock",
"VideoNarrationBlock",
"VideoTextOverlayBlock",
]

View File

@@ -1,131 +0,0 @@
"""Shared utilities for video blocks."""
from __future__ import annotations
import logging
import os
import re
import subprocess
from pathlib import Path
logger = logging.getLogger(__name__)
# Known operation tags added by video blocks
_VIDEO_OPS = (
r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)"
)
# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID
_BLOCK_PREFIX_RE = re.compile(
r"^[a-zA-Z0-9_-]*"
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
r"[a-zA-Z0-9_-]*"
r"_" + _VIDEO_OPS + r"_"
)
# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output)
_UUID_PREFIX_RE = re.compile(
r"^[a-zA-Z0-9_-]*"
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
r"[a-zA-Z0-9_-]*_"
)
def extract_source_name(input_path: str, max_length: int = 50) -> str:
"""Extract the original source filename by stripping block-generated prefixes.
Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate
when chaining video blocks, recovering the original human-readable name.
Safe for plain filenames (no UUID -> no stripping).
Falls back to "video" if everything is stripped.
"""
stem = Path(input_path).stem
# Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively
while _BLOCK_PREFIX_RE.match(stem):
stem = _BLOCK_PREFIX_RE.sub("", stem, count=1)
# Pass 2: strip a lone {node_exec_id}_ prefix (e.g. from download block)
if _UUID_PREFIX_RE.match(stem):
stem = _UUID_PREFIX_RE.sub("", stem, count=1)
if not stem:
return "video"
return stem[:max_length]
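# Worked example (hypothetical filename): one block-generated prefix is
# stripped per pass, recovering the human-readable stem:
#   extract_source_name("12345678-1234-1234-1234-123456789abc_clip_my_video.mp4")
#   -> "my_video"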
def get_video_codecs(output_path: str) -> tuple[str, str]:
"""Get appropriate video and audio codecs based on output file extension.
Args:
output_path: Path to the output file (used to determine extension)
Returns:
Tuple of (video_codec, audio_codec)
Codec mappings:
- .mp4: H.264 + AAC (universal compatibility)
- .webm: VP8 + Vorbis (web streaming)
- .mkv: H.264 + AAC (container supports many codecs)
- .mov: H.264 + AAC (Apple QuickTime, widely compatible)
- .m4v: H.264 + AAC (Apple iTunes/devices)
- .avi: MPEG-4 + MP3 (legacy Windows)
"""
ext = os.path.splitext(output_path)[1].lower()
codec_map: dict[str, tuple[str, str]] = {
".mp4": ("libx264", "aac"),
".webm": ("libvpx", "libvorbis"),
".mkv": ("libx264", "aac"),
".mov": ("libx264", "aac"),
".m4v": ("libx264", "aac"),
".avi": ("mpeg4", "libmp3lame"),
}
return codec_map.get(ext, ("libx264", "aac"))
def strip_chapters_inplace(video_path: str) -> None:
"""Strip chapter metadata from a media file in-place using ffmpeg.
MoviePy 2.x crashes with IndexError when parsing files with embedded
chapter metadata (https://github.com/Zulko/moviepy/issues/2419).
This strips chapters without re-encoding.
Args:
video_path: Absolute path to the media file to strip chapters from.
"""
base, ext = os.path.splitext(video_path)
tmp_path = base + ".tmp" + ext
try:
result = subprocess.run(
[
"ffmpeg",
"-y",
"-i",
video_path,
"-map_chapters",
"-1",
"-codec",
"copy",
tmp_path,
],
capture_output=True,
text=True,
timeout=300,
)
if result.returncode != 0:
logger.warning(
"ffmpeg chapter strip failed (rc=%d): %s",
result.returncode,
result.stderr,
)
return
os.replace(tmp_path, video_path)
except FileNotFoundError:
logger.warning("ffmpeg not found; skipping chapter strip")
finally:
if os.path.exists(tmp_path):
os.unlink(tmp_path)
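# The subprocess call above is equivalent to running (path illustrative):
#   ffmpeg -y -i video.mp4 -map_chapters -1 -codec copy video.tmp.mp4
# i.e. all streams are copied untouched while chapter metadata is dropped,
# then os.replace swaps the result over the original file.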

View File

@@ -1,113 +0,0 @@
"""AddAudioToVideoBlock - Attach an audio track to a video file."""
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class AddAudioToVideoBlock(Block):
"""Add (attach) an audio track to an existing video."""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="Video input (URL, data URI, or local path)."
)
audio_in: MediaFileType = SchemaField(
description="Audio input (URL, data URI, or local path)."
)
volume: float = SchemaField(
description="Volume scale for the newly attached audio track (1.0 = original).",
default=1.0,
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Final video (with attached audio), as a path or data URI."
)
def __init__(self):
super().__init__(
id="3503748d-62b6-4425-91d6-725b064af509",
description="Block to attach an audio file to a video file using moviepy.",
categories={BlockCategory.MULTIMEDIA},
input_schema=AddAudioToVideoBlock.Input,
output_schema=AddAudioToVideoBlock.Output,
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the inputs locally
local_video_path = await store_media_file(
file=input_data.video_in,
execution_context=execution_context,
return_format="for_local_processing",
)
local_audio_path = await store_media_file(
file=input_data.audio_in,
execution_context=execution_context,
return_format="for_local_processing",
)
video_abspath = get_exec_file_path(graph_exec_id, local_video_path)
audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path)
# 2) Load video + audio with moviepy
strip_chapters_inplace(video_abspath)
strip_chapters_inplace(audio_abspath)
video_clip = None
audio_clip = None
final_clip = None
try:
video_clip = VideoFileClip(video_abspath)
audio_clip = AudioFileClip(audio_abspath)
# Optionally scale volume
if input_data.volume != 1.0:
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
# 3) Attach the new audio track
final_clip = video_clip.with_audio(audio_clip)
# 4) Write to output file
source = extract_source_name(local_video_path)
output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4")
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
final_clip.write_videofile(
output_abspath, codec="libx264", audio_codec="aac"
)
finally:
if final_clip:
final_clip.close()
if audio_clip:
audio_clip.close()
if video_clip:
video_clip.close()
# 5) Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
file=output_filename,
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out

View File

@@ -1,167 +0,0 @@
"""VideoClipBlock - Extract a segment from a video file."""
from typing import Literal
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import (
extract_source_name,
get_video_codecs,
strip_chapters_inplace,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class VideoClipBlock(Block):
"""Extract a time segment from a video."""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="Input video (URL, data URI, or local path)"
)
start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
end_time: float = SchemaField(description="End time in seconds", ge=0.0)
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
description="Output format", default="mp4", advanced=True
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Clipped video file (path or data URI)"
)
duration: float = SchemaField(description="Clip duration in seconds")
def __init__(self):
super().__init__(
id="8f539119-e580-4d86-ad41-86fbcb22abb1",
description="Extract a time segment from a video",
categories={BlockCategory.MULTIMEDIA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"video_in": "/tmp/test.mp4",
"start_time": 0.0,
"end_time": 10.0,
},
test_output=[("video_out", str), ("duration", float)],
test_mock={
"_clip_video": lambda *args: 10.0,
"_store_input_video": lambda *args, **kwargs: "test.mp4",
"_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
},
)
async def _store_input_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store input video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_local_processing",
)
async def _store_output_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store output video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_block_output",
)
def _clip_video(
self,
video_abspath: str,
output_abspath: str,
start_time: float,
end_time: float,
) -> float:
"""Extract a clip from a video. Extracted for testability."""
clip = None
subclip = None
try:
strip_chapters_inplace(video_abspath)
clip = VideoFileClip(video_abspath)
subclip = clip.subclipped(start_time, end_time)
video_codec, audio_codec = get_video_codecs(output_abspath)
subclip.write_videofile(
output_abspath, codec=video_codec, audio_codec=audio_codec
)
return subclip.duration
finally:
if subclip:
subclip.close()
if clip:
clip.close()
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
# Validate time range
if input_data.end_time <= input_data.start_time:
raise BlockExecutionError(
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
block_name=self.name,
block_id=str(self.id),
)
try:
assert execution_context.graph_exec_id is not None
# Store the input video locally
local_video_path = await self._store_input_video(
execution_context, input_data.video_in
)
video_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_video_path
)
# Build output path
source = extract_source_name(local_video_path)
output_filename = MediaFileType(
f"{node_exec_id}_clip_{source}.{input_data.output_format}"
)
output_abspath = get_exec_file_path(
execution_context.graph_exec_id, output_filename
)
duration = self._clip_video(
video_abspath,
output_abspath,
input_data.start_time,
input_data.end_time,
)
# Return as workspace path or data URI based on context
video_out = await self._store_output_video(
execution_context, output_filename
)
yield "video_out", video_out
yield "duration", duration
except BlockExecutionError:
raise
except Exception as e:
raise BlockExecutionError(
message=f"Failed to clip video: {e}",
block_name=self.name,
block_id=str(self.id),
) from e
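The _clip_video/_store_* split above exists so test_mock can stub the heavy I/O. A minimal sketch of that pattern outside the block harness (names illustrative):

class ClipperStub:
    def _clip_video(self, src: str, dst: str, start: float, end: float) -> float:
        raise NotImplementedError("real moviepy work lives here in production")

clipper = ClipperStub()
# Mirrors test_mock: swap the side-effectful method for a cheap stand-in.
clipper._clip_video = lambda *args: 10.0
assert clipper._clip_video("in.mp4", "clip.mp4", 0.0, 10.0) == 10.0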

View File

@@ -1,227 +0,0 @@
"""VideoConcatBlock - Concatenate multiple video clips into one."""
from typing import Literal
from moviepy import concatenate_videoclips
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import (
extract_source_name,
get_video_codecs,
strip_chapters_inplace,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class VideoConcatBlock(Block):
"""Merge multiple video clips into one continuous video."""
class Input(BlockSchemaInput):
videos: list[MediaFileType] = SchemaField(
description="List of video files to concatenate (in order)"
)
transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
description="Transition between clips", default="none"
)
transition_duration: int = SchemaField(
description="Transition duration in seconds",
default=1,
ge=0,
advanced=True,
)
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
description="Output format", default="mp4", advanced=True
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Concatenated video file (path or data URI)"
)
total_duration: float = SchemaField(description="Total duration in seconds")
def __init__(self):
super().__init__(
id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
description="Merge multiple video clips into one continuous video",
categories={BlockCategory.MULTIMEDIA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"videos": ["/tmp/a.mp4", "/tmp/b.mp4"],
},
test_output=[
("video_out", str),
("total_duration", float),
],
test_mock={
"_concat_videos": lambda *args: 20.0,
"_store_input_video": lambda *args, **kwargs: "test.mp4",
"_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
},
)
async def _store_input_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store input video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_local_processing",
)
async def _store_output_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store output video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_block_output",
)
def _concat_videos(
self,
video_abspaths: list[str],
output_abspath: str,
transition: str,
transition_duration: int,
) -> float:
"""Concatenate videos. Extracted for testability.
Returns:
Total duration of the concatenated video.
"""
clips = []
faded_clips = []
final = None
try:
# Load clips
for v in video_abspaths:
strip_chapters_inplace(v)
clips.append(VideoFileClip(v))
# Validate transition_duration against shortest clip
if transition in {"crossfade", "fade_black"} and transition_duration > 0:
min_duration = min(c.duration for c in clips)
if transition_duration >= min_duration:
raise BlockExecutionError(
message=(
f"transition_duration ({transition_duration}s) must be "
f"shorter than the shortest clip ({min_duration:.2f}s)"
),
block_name=self.name,
block_id=str(self.id),
)
if transition == "crossfade":
for i, clip in enumerate(clips):
effects = []
if i > 0:
effects.append(CrossFadeIn(transition_duration))
if i < len(clips) - 1:
effects.append(CrossFadeOut(transition_duration))
if effects:
clip = clip.with_effects(effects)
faded_clips.append(clip)
final = concatenate_videoclips(
faded_clips,
method="compose",
padding=-transition_duration,
)
elif transition == "fade_black":
for clip in clips:
faded = clip.with_effects(
[FadeIn(transition_duration), FadeOut(transition_duration)]
)
faded_clips.append(faded)
final = concatenate_videoclips(faded_clips)
else:
final = concatenate_videoclips(clips)
video_codec, audio_codec = get_video_codecs(output_abspath)
final.write_videofile(
output_abspath, codec=video_codec, audio_codec=audio_codec
)
return final.duration
finally:
if final:
final.close()
for clip in faded_clips:
clip.close()
for clip in clips:
clip.close()
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
# Validate minimum clips
if len(input_data.videos) < 2:
raise BlockExecutionError(
message="At least 2 videos are required for concatenation",
block_name=self.name,
block_id=str(self.id),
)
try:
assert execution_context.graph_exec_id is not None
# Store all input videos locally
video_abspaths = []
for video in input_data.videos:
local_path = await self._store_input_video(execution_context, video)
video_abspaths.append(
get_exec_file_path(execution_context.graph_exec_id, local_path)
)
# Build output path
source = (
extract_source_name(video_abspaths[0]) if video_abspaths else "video"
)
output_filename = MediaFileType(
f"{node_exec_id}_concat_{source}.{input_data.output_format}"
)
output_abspath = get_exec_file_path(
execution_context.graph_exec_id, output_filename
)
total_duration = self._concat_videos(
video_abspaths,
output_abspath,
input_data.transition,
input_data.transition_duration,
)
# Return as workspace path or data URI based on context
video_out = await self._store_output_video(
execution_context, output_filename
)
yield "video_out", video_out
yield "total_duration", total_duration
except BlockExecutionError:
raise
except Exception as e:
raise BlockExecutionError(
message=f"Failed to concatenate videos: {e}",
block_name=self.name,
block_id=str(self.id),
) from e
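A worked sketch of the crossfade arithmetic above (derived from padding=-transition_duration with method="compose", not measured output): each clip starts d seconds before its predecessor ends, so the fades overlap.

def composed_duration(clip_lengths: list[float], d: float) -> float:
    # With padding=-d, each of the n-1 joins overlaps by d seconds.
    n = len(clip_lengths)
    return sum(clip_lengths) - d * (n - 1)

assert composed_duration([10.0, 12.0], 1.0) == 21.0

This is also why the block validates transition_duration against the shortest clip: a fade as long as a clip would consume it entirely.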

View File

@@ -1,172 +0,0 @@
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
import os
import typing
from typing import Literal
import yt_dlp
if typing.TYPE_CHECKING:
from yt_dlp import _Params
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class VideoDownloadBlock(Block):
"""Download video from URL using yt-dlp."""
class Input(BlockSchemaInput):
url: str = SchemaField(
description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
placeholder="https://www.youtube.com/watch?v=...",
)
quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
description="Video quality preference", default="720p"
)
output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
description="Output video format", default="mp4", advanced=True
)
class Output(BlockSchemaOutput):
video_file: MediaFileType = SchemaField(
description="Downloaded video (path or data URI)"
)
duration: float = SchemaField(description="Video duration in seconds")
title: str = SchemaField(description="Video title from source")
source_url: str = SchemaField(description="Original source URL")
def __init__(self):
super().__init__(
id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
categories={BlockCategory.MULTIMEDIA},
input_schema=self.Input,
output_schema=self.Output,
disabled=True, # Disable until we can sandbox yt-dlp and handle security implications
test_input={
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
"quality": "480p",
},
test_output=[
("video_file", str),
("duration", float),
("title", str),
("source_url", str),
],
test_mock={
"_download_video": lambda *args: (
"video.mp4",
212.0,
"Test Video",
),
"_store_output_video": lambda *args, **kwargs: "video.mp4",
},
)
async def _store_output_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store output video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_block_output",
)
def _get_format_string(self, quality: str) -> str:
formats = {
"best": "bestvideo+bestaudio/best",
"1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
"720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
"480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
"audio_only": "bestaudio/best",
}
return formats.get(quality, formats["720p"])
def _download_video(
self,
url: str,
quality: str,
output_format: str,
output_dir: str,
node_exec_id: str,
) -> tuple[str, float, str]:
"""Download video. Extracted for testability."""
output_template = os.path.join(
output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
)
ydl_opts: "_Params" = {
"format": f"{self._get_format_string(quality)}/best",
"outtmpl": output_template,
"merge_output_format": output_format,
"quiet": True,
"no_warnings": True,
}
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
info = ydl.extract_info(url, download=True)
video_path = ydl.prepare_filename(info)
# Handle format conversion in filename
if not video_path.endswith(f".{output_format}"):
video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"
# Return just the filename, not the full path
filename = os.path.basename(video_path)
return (
filename,
info.get("duration") or 0.0,
info.get("title") or "Unknown",
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
try:
assert execution_context.graph_exec_id is not None
# Get the exec file directory
output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
os.makedirs(output_dir, exist_ok=True)
filename, duration, title = self._download_video(
input_data.url,
input_data.quality,
input_data.output_format,
output_dir,
node_exec_id,
)
# Return as workspace path or data URI based on context
video_out = await self._store_output_video(
execution_context, MediaFileType(filename)
)
yield "video_file", video_out
yield "duration", duration
yield "title", title
yield "source_url", input_data.url
except Exception as e:
raise BlockExecutionError(
message=f"Failed to download video: {e}",
block_name=self.name,
block_id=str(self.id),
) from e
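For reference, the quality preference expands into a yt-dlp format selector before download; a worked example from the mapping above:

# quality="720p" becomes, via _get_format_string plus the "/best" fallback
# appended in ydl_opts above:
#   "bestvideo[height<=720]+bestaudio/best[height<=720]/best"
# i.e. best <=720p video+audio pair, else best <=720p single file, else
# whatever single "best" format the extractor offers.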

View File

@@ -1,77 +0,0 @@
"""MediaDurationBlock - Get the duration of a media file."""
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import strip_chapters_inplace
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class MediaDurationBlock(Block):
"""Get the duration of a media file (video or audio)."""
class Input(BlockSchemaInput):
media_in: MediaFileType = SchemaField(
description="Media input (URL, data URI, or local path)."
)
is_video: bool = SchemaField(
description="Whether the media is a video (True) or audio (False).",
default=True,
)
class Output(BlockSchemaOutput):
duration: float = SchemaField(
description="Duration of the media file (in seconds)."
)
def __init__(self):
super().__init__(
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
description="Block to get the duration of a media file.",
categories={BlockCategory.MULTIMEDIA},
input_schema=MediaDurationBlock.Input,
output_schema=MediaDurationBlock.Output,
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
file=input_data.media_in,
execution_context=execution_context,
return_format="for_local_processing",
)
assert execution_context.graph_exec_id is not None
media_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_media_path
)
# 2) Strip chapters to avoid MoviePy crash, then load the clip
strip_chapters_inplace(media_abspath)
clip = None
try:
if input_data.is_video:
clip = VideoFileClip(media_abspath)
else:
clip = AudioFileClip(media_abspath)
duration = clip.duration
finally:
if clip:
clip.close()
yield "duration", duration

View File

@@ -1,115 +0,0 @@
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""
from typing import Optional
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class LoopVideoBlock(Block):
"""Loop (repeat) a video clip until a given duration or number of loops."""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="The input video (can be a URL, data URI, or local path)."
)
duration: Optional[float] = SchemaField(
description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.",
default=None,
ge=0.0,
le=3600.0, # Max 1 hour to prevent disk exhaustion
)
n_loops: Optional[int] = SchemaField(
description="Number of times to repeat the video. Either n_loops or duration must be provided.",
default=None,
ge=1,
le=10, # Max 10 loops to prevent disk exhaustion
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Looped video returned either as a relative path or a data URI."
)
def __init__(self):
super().__init__(
id="8bf9eef6-5451-4213-b265-25306446e94b",
description="Block to loop a video to a given duration or number of repeats.",
categories={BlockCategory.MULTIMEDIA},
input_schema=LoopVideoBlock.Input,
output_schema=LoopVideoBlock.Output,
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the input video locally
local_video_path = await store_media_file(
file=input_data.video_in,
execution_context=execution_context,
return_format="for_local_processing",
)
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
# 2) Load the clip
strip_chapters_inplace(input_abspath)
clip = None
looped_clip = None
try:
clip = VideoFileClip(input_abspath)
# 3) Apply the loop effect
if input_data.duration:
# Loop until we reach the specified duration
looped_clip = clip.with_effects([Loop(duration=input_data.duration)])
elif input_data.n_loops:
looped_clip = clip.with_effects([Loop(n=input_data.n_loops)])
else:
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
assert isinstance(looped_clip, VideoFileClip)
# 4) Save the looped output
source = extract_source_name(local_video_path)
output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4")
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
looped_clip = looped_clip.with_audio(clip.audio)
looped_clip.write_videofile(
output_abspath, codec="libx264", audio_codec="aac"
)
finally:
if looped_clip:
looped_clip.close()
if clip:
clip.close()
# Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
file=output_filename,
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out

View File

@@ -1,267 +0,0 @@
"""VideoNarrationBlock - Generate AI voice narration and add to video."""
import os
from typing import Literal
from elevenlabs import ElevenLabs
from moviepy import CompositeAudioClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.elevenlabs._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ElevenLabsCredentials,
ElevenLabsCredentialsInput,
)
from backend.blocks.video._utils import (
extract_source_name,
get_video_codecs,
strip_chapters_inplace,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsField, SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class VideoNarrationBlock(Block):
"""Generate AI narration and add to video."""
class Input(BlockSchemaInput):
credentials: ElevenLabsCredentialsInput = CredentialsField(
description="ElevenLabs API key for voice synthesis"
)
video_in: MediaFileType = SchemaField(
description="Input video (URL, data URI, or local path)"
)
script: str = SchemaField(description="Narration script text")
voice_id: str = SchemaField(
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
)
model_id: Literal[
"eleven_multilingual_v2",
"eleven_flash_v2_5",
"eleven_turbo_v2_5",
"eleven_turbo_v2",
] = SchemaField(
description="ElevenLabs TTS model",
default="eleven_multilingual_v2",
)
mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
default="ducking",
)
narration_volume: float = SchemaField(
description="Narration volume (0.0 to 2.0)",
default=1.0,
ge=0.0,
le=2.0,
advanced=True,
)
original_volume: float = SchemaField(
description="Original audio volume when mixing (0.0 to 1.0)",
default=0.3,
ge=0.0,
le=1.0,
advanced=True,
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Video with narration (path or data URI)"
)
audio_file: MediaFileType = SchemaField(
description="Generated audio file (path or data URI)"
)
def __init__(self):
super().__init__(
id="3d036b53-859c-4b17-9826-ca340f736e0e",
description="Generate AI narration and add to video",
categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"video_in": "/tmp/test.mp4",
"script": "Hello world",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("video_out", str), ("audio_file", str)],
test_mock={
"_generate_narration_audio": lambda *args: b"mock audio content",
"_add_narration_to_video": lambda *args: None,
"_store_input_video": lambda *args, **kwargs: "test.mp4",
"_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
},
)
async def _store_input_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store input video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_local_processing",
)
async def _store_output_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store output video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_block_output",
)
def _generate_narration_audio(
self, api_key: str, script: str, voice_id: str, model_id: str
) -> bytes:
"""Generate narration audio via ElevenLabs API."""
client = ElevenLabs(api_key=api_key)
audio_generator = client.text_to_speech.convert(
voice_id=voice_id,
text=script,
model_id=model_id,
)
# The SDK returns a generator, collect all chunks
return b"".join(audio_generator)
def _add_narration_to_video(
self,
video_abspath: str,
audio_abspath: str,
output_abspath: str,
mix_mode: str,
narration_volume: float,
original_volume: float,
) -> None:
"""Add narration audio to video. Extracted for testability."""
video = None
final = None
narration_original = None
narration_scaled = None
original = None
try:
strip_chapters_inplace(video_abspath)
video = VideoFileClip(video_abspath)
narration_original = AudioFileClip(audio_abspath)
narration_scaled = narration_original.with_volume_scaled(narration_volume)
narration = narration_scaled
if mix_mode == "replace":
final_audio = narration
elif mix_mode == "mix":
if video.audio:
original = video.audio.with_volume_scaled(original_volume)
final_audio = CompositeAudioClip([original, narration])
else:
final_audio = narration
else: # ducking - apply stronger attenuation
if video.audio:
# Ducking uses a much lower volume for original audio
ducking_volume = original_volume * 0.3
original = video.audio.with_volume_scaled(ducking_volume)
final_audio = CompositeAudioClip([original, narration])
else:
final_audio = narration
final = video.with_audio(final_audio)
video_codec, audio_codec = get_video_codecs(output_abspath)
final.write_videofile(
output_abspath, codec=video_codec, audio_codec=audio_codec
)
finally:
if original:
original.close()
if narration_scaled:
narration_scaled.close()
if narration_original:
narration_original.close()
if final:
final.close()
if video:
video.close()
async def run(
self,
input_data: Input,
*,
credentials: ElevenLabsCredentials,
execution_context: ExecutionContext,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
try:
assert execution_context.graph_exec_id is not None
# Store the input video locally
local_video_path = await self._store_input_video(
execution_context, input_data.video_in
)
video_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_video_path
)
# Generate narration audio via ElevenLabs
audio_content = self._generate_narration_audio(
credentials.api_key.get_secret_value(),
input_data.script,
input_data.voice_id,
input_data.model_id,
)
# Save audio to exec file path
audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
audio_abspath = get_exec_file_path(
execution_context.graph_exec_id, audio_filename
)
os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
with open(audio_abspath, "wb") as f:
f.write(audio_content)
# Add narration to video
source = extract_source_name(local_video_path)
output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4")
output_abspath = get_exec_file_path(
execution_context.graph_exec_id, output_filename
)
self._add_narration_to_video(
video_abspath,
audio_abspath,
output_abspath,
input_data.mix_mode,
input_data.narration_volume,
input_data.original_volume,
)
# Return as workspace path or data URI based on context
video_out = await self._store_output_video(
execution_context, output_filename
)
audio_out = await self._store_output_video(
execution_context, audio_filename
)
yield "video_out", video_out
yield "audio_file", audio_out
except Exception as e:
raise BlockExecutionError(
message=f"Failed to add narration: {e}",
block_name=self.name,
block_id=str(self.id),
) from e

View File

@@ -1,231 +0,0 @@
"""VideoTextOverlayBlock - Add text overlay to video."""
from typing import Literal
from moviepy import CompositeVideoClip, TextClip
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import (
extract_source_name,
get_video_codecs,
strip_chapters_inplace,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class VideoTextOverlayBlock(Block):
"""Add text overlay/caption to video."""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="Input video (URL, data URI, or local path)"
)
text: str = SchemaField(description="Text to overlay on video")
position: Literal[
"top",
"center",
"bottom",
"top-left",
"top-right",
"bottom-left",
"bottom-right",
] = SchemaField(description="Position of text on screen", default="bottom")
start_time: float | None = SchemaField(
description="When to show text (seconds). None = entire video",
default=None,
advanced=True,
)
end_time: float | None = SchemaField(
description="When to hide text (seconds). None = until end",
default=None,
advanced=True,
)
font_size: int = SchemaField(
description="Font size", default=48, ge=12, le=200, advanced=True
)
font_color: str = SchemaField(
description="Font color (hex or name)", default="white", advanced=True
)
bg_color: str | None = SchemaField(
description="Background color behind text (None for transparent)",
default=None,
advanced=True,
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Video with text overlay (path or data URI)"
)
def __init__(self):
super().__init__(
id="8ef14de6-cc90-430a-8cfa-3a003be92454",
description="Add text overlay/caption to video",
categories={BlockCategory.MULTIMEDIA},
input_schema=self.Input,
output_schema=self.Output,
disabled=True, # Disable until we can lockdown imagemagick security policy
test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
test_output=[("video_out", str)],
test_mock={
"_add_text_overlay": lambda *args: None,
"_store_input_video": lambda *args, **kwargs: "test.mp4",
"_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
},
)
async def _store_input_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store input video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_local_processing",
)
async def _store_output_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store output video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_block_output",
)
def _add_text_overlay(
self,
video_abspath: str,
output_abspath: str,
text: str,
position: str,
start_time: float | None,
end_time: float | None,
font_size: int,
font_color: str,
bg_color: str | None,
) -> None:
"""Add text overlay to video. Extracted for testability."""
video = None
final = None
txt_clip = None
try:
strip_chapters_inplace(video_abspath)
video = VideoFileClip(video_abspath)
txt_clip = TextClip(
text=text,
font_size=font_size,
color=font_color,
bg_color=bg_color,
)
# Position mapping
pos_map = {
"top": ("center", "top"),
"center": ("center", "center"),
"bottom": ("center", "bottom"),
"top-left": ("left", "top"),
"top-right": ("right", "top"),
"bottom-left": ("left", "bottom"),
"bottom-right": ("right", "bottom"),
}
txt_clip = txt_clip.with_position(pos_map[position])
# Set timing
start = start_time or 0
end = end_time or video.duration
duration = max(0, end - start)
txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)
final = CompositeVideoClip([video, txt_clip])
video_codec, audio_codec = get_video_codecs(output_abspath)
final.write_videofile(
output_abspath, codec=video_codec, audio_codec=audio_codec
)
finally:
if txt_clip:
txt_clip.close()
if final:
final.close()
if video:
video.close()
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
# Validate time range if both are provided
if (
input_data.start_time is not None
and input_data.end_time is not None
and input_data.end_time <= input_data.start_time
):
raise BlockExecutionError(
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
block_name=self.name,
block_id=str(self.id),
)
try:
assert execution_context.graph_exec_id is not None
# Store the input video locally
local_video_path = await self._store_input_video(
execution_context, input_data.video_in
)
video_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_video_path
)
# Build output path
source = extract_source_name(local_video_path)
output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4")
output_abspath = get_exec_file_path(
execution_context.graph_exec_id, output_filename
)
self._add_text_overlay(
video_abspath,
output_abspath,
input_data.text,
input_data.position,
input_data.start_time,
input_data.end_time,
input_data.font_size,
input_data.font_color,
input_data.bg_color,
)
# Return as workspace path or data URI based on context
video_out = await self._store_output_video(
execution_context, output_filename
)
yield "video_out", video_out
except BlockExecutionError:
raise
except Exception as e:
raise BlockExecutionError(
message=f"Failed to add text overlay: {e}",
block_name=self.name,
block_id=str(self.id),
) from e

View File

@@ -165,13 +165,10 @@ class TranscribeYoutubeVideoBlock(Block):
credentials: WebshareProxyCredentials, credentials: WebshareProxyCredentials,
**kwargs, **kwargs,
) -> BlockOutput: ) -> BlockOutput:
try: video_id = self.extract_video_id(input_data.youtube_url)
video_id = self.extract_video_id(input_data.youtube_url) yield "video_id", video_id
transcript = self.get_transcript(video_id, credentials)
transcript_text = self.format_transcript(transcript=transcript)
# Only yield after all operations succeed transcript = self.get_transcript(video_id, credentials)
yield "video_id", video_id transcript_text = self.format_transcript(transcript=transcript)
yield "transcript", transcript_text
except Exception as e: yield "transcript", transcript_text
yield "error", str(e)

View File

@@ -246,9 +246,7 @@ class BlockSchema(BaseModel):
f"is not of type {CredentialsMetaInput.__name__}" f"is not of type {CredentialsMetaInput.__name__}"
) )
CredentialsMetaInput.validate_credentials_field_schema( credentials_fields[field_name].validate_credentials_field_schema(cls)
cls.get_field_schema(field_name), field_name
)
elif field_name in credentials_fields: elif field_name in credentials_fields:
raise KeyError( raise KeyError(

View File

@@ -36,14 +36,12 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.data.block import Block, BlockCost, BlockCostType from backend.data.block import Block, BlockCost, BlockCostType
from backend.integrations.credentials_store import ( from backend.integrations.credentials_store import (
aiml_api_credentials, aiml_api_credentials,
anthropic_credentials, anthropic_credentials,
apollo_credentials, apollo_credentials,
did_credentials, did_credentials,
elevenlabs_credentials,
enrichlayer_credentials, enrichlayer_credentials,
groq_credentials, groq_credentials,
ideogram_credentials, ideogram_credentials,
@@ -80,7 +78,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.CLAUDE_4_1_OPUS: 21, LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21, LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5, LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_6_OPUS: 14,
LlmModel.CLAUDE_4_5_HAIKU: 4, LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_OPUS: 14, LlmModel.CLAUDE_4_5_OPUS: 14,
LlmModel.CLAUDE_4_5_SONNET: 9, LlmModel.CLAUDE_4_5_SONNET: 9,
@@ -642,16 +639,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
}, },
), ),
], ],
VideoNarrationBlock: [
BlockCost(
cost_amount=5, # ElevenLabs TTS cost
cost_filter={
"credentials": {
"id": elevenlabs_credentials.id,
"provider": elevenlabs_credentials.provider,
"type": elevenlabs_credentials.type,
}
},
)
],
} }

View File

@@ -134,16 +134,6 @@ async def test_block_credit_reset(server: SpinTestServer):
month1 = datetime.now(timezone.utc).replace(month=1, day=1) month1 = datetime.now(timezone.utc).replace(month=1, day=1)
user_credit.time_now = lambda: month1 user_credit.time_now = lambda: month1
# IMPORTANT: Set updatedAt to December of previous year to ensure it's
# in a different month than month1 (January). This fixes a timing bug
# where if the test runs in early February, 35 days ago would be January,
# matching the mocked month1 and preventing the refill from triggering.
dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
await UserBalance.prisma().update(
where={"userId": DEFAULT_USER_ID},
data={"updatedAt": dec_previous_year},
)
# First call in month 1 should trigger refill # First call in month 1 should trigger refill
balance = await user_credit.get_credits(DEFAULT_USER_ID) balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert balance == REFILL_VALUE # Should get 1000 credits assert balance == REFILL_VALUE # Should get 1000 credits

View File

@@ -1,8 +1,9 @@
import logging import logging
import queue
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Annotated, Annotated,
@@ -1199,16 +1200,12 @@ class NodeExecutionEntry(BaseModel):
class ExecutionQueue(Generic[T]): class ExecutionQueue(Generic[T]):
""" """
Thread-safe queue for managing node execution within a single graph execution. Queue for managing the execution of agents.
This will be shared between different processes
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
threads within the same process. If migrating back to ProcessPoolExecutor,
replace with multiprocessing.Manager().Queue() for cross-process safety.
""" """
def __init__(self): def __init__(self):
# Thread-safe queue (not multiprocessing) — see class docstring self.queue = Manager().Queue()
self.queue: queue.Queue[T] = queue.Queue()
def add(self, execution: T) -> T: def add(self, execution: T) -> T:
self.queue.put(execution) self.queue.put(execution)
@@ -1223,7 +1220,7 @@ class ExecutionQueue(Generic[T]):
def get_or_none(self) -> T | None: def get_or_none(self) -> T | None:
try: try:
return self.queue.get_nowait() return self.queue.get_nowait()
except queue.Empty: except Empty:
return None return None

View File

@@ -1,58 +0,0 @@
"""Tests for ExecutionQueue thread-safety."""
import queue
import threading
from backend.data.execution import ExecutionQueue
def test_execution_queue_uses_stdlib_queue():
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
q = ExecutionQueue()
assert isinstance(q.queue, queue.Queue)
def test_basic_operations():
"""Test add, get, empty, and get_or_none."""
q = ExecutionQueue()
assert q.empty() is True
assert q.get_or_none() is None
result = q.add("item1")
assert result == "item1"
assert q.empty() is False
item = q.get()
assert item == "item1"
assert q.empty() is True
def test_thread_safety():
"""Test concurrent access from multiple threads."""
q = ExecutionQueue()
results = []
num_items = 100
def producer():
for i in range(num_items):
q.add(f"item_{i}")
def consumer():
count = 0
while count < num_items:
item = q.get_or_none()
if item is not None:
results.append(item)
count += 1
producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)
producer_thread.start()
consumer_thread.start()
producer_thread.join(timeout=5)
consumer_thread.join(timeout=5)
assert len(results) == num_items

View File

@@ -3,7 +3,7 @@ import logging
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Self, cast from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
from prisma.enums import SubmissionStatus from prisma.enums import SubmissionStatus
from prisma.models import ( from prisma.models import (
@@ -20,7 +20,7 @@ from prisma.types import (
AgentNodeLinkCreateInput, AgentNodeLinkCreateInput,
StoreListingVersionWhereInput, StoreListingVersionWhereInput,
) )
from pydantic import BaseModel, BeforeValidator, Field from pydantic import BaseModel, BeforeValidator, Field, create_model
from pydantic.fields import computed_field from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock from backend.blocks.agent import AgentExecutorBlock
@@ -30,6 +30,7 @@ from backend.data.db import prisma as db
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
from backend.data.model import ( from backend.data.model import (
CredentialsField,
CredentialsFieldInfo, CredentialsFieldInfo,
CredentialsMetaInput, CredentialsMetaInput,
is_credentials_field_name, is_credentials_field_name,
@@ -44,6 +45,7 @@ from .block import (
AnyBlockSchema, AnyBlockSchema,
Block, Block,
BlockInput, BlockInput,
BlockSchema,
BlockType, BlockType,
EmptySchema, EmptySchema,
get_block, get_block,
@@ -111,12 +113,10 @@ class Link(BaseDbModel):
class Node(BaseDbModel): class Node(BaseDbModel):
block_id: str block_id: str
input_default: BlockInput = Field( # dict[input_name, default_value] input_default: BlockInput = {} # dict[input_name, default_value]
default_factory=dict metadata: dict[str, Any] = {}
) input_links: list[Link] = []
metadata: dict[str, Any] = Field(default_factory=dict) output_links: list[Link] = []
input_links: list[Link] = Field(default_factory=list)
output_links: list[Link] = Field(default_factory=list)
@property @property
def credentials_optional(self) -> bool: def credentials_optional(self) -> bool:
@@ -221,33 +221,18 @@ class NodeModel(Node):
return result return result
class GraphBaseMeta(BaseDbModel): class BaseGraph(BaseDbModel):
"""
Shared base for `GraphMeta` and `BaseGraph`, with core graph metadata fields.
"""
version: int = 1 version: int = 1
is_active: bool = True is_active: bool = True
name: str name: str
description: str description: str
instructions: str | None = None instructions: str | None = None
recommended_schedule_cron: str | None = None recommended_schedule_cron: str | None = None
nodes: list[Node] = []
links: list[Link] = []
forked_from_id: str | None = None forked_from_id: str | None = None
forked_from_version: int | None = None forked_from_version: int | None = None
class BaseGraph(GraphBaseMeta):
"""
Graph with nodes, links, and computed I/O schema fields.
Used to represent sub-graphs within a `Graph`. Contains the full graph
structure including nodes and links, plus computed fields for schemas
and trigger info. Does NOT include user_id or created_at (see GraphModel).
"""
nodes: list[Node] = Field(default_factory=list)
links: list[Link] = Field(default_factory=list)
@computed_field @computed_field
@property @property
def input_schema(self) -> dict[str, Any]: def input_schema(self) -> dict[str, Any]:
@@ -376,79 +361,44 @@ class GraphTriggerInfo(BaseModel):
class Graph(BaseGraph): class Graph(BaseGraph):
"""Creatable graph model used in API create/update endpoints.""" sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
sub_graphs: list[BaseGraph] = Field(default_factory=list) # Flattened sub-graphs
class GraphMeta(GraphBaseMeta):
"""
Lightweight graph metadata model representing an existing graph from the database,
for use in listings and summaries.
Lacks `GraphModel`'s nodes, links, and expensive computed fields.
Use for list endpoints where full graph data is not needed and performance matters.
"""
id: str # type: ignore
version: int # type: ignore
user_id: str
created_at: datetime
@classmethod
def from_db(cls, graph: "AgentGraph") -> Self:
return cls(
id=graph.id,
version=graph.version,
is_active=graph.isActive,
name=graph.name or "",
description=graph.description or "",
instructions=graph.instructions,
recommended_schedule_cron=graph.recommendedScheduleCron,
forked_from_id=graph.forkedFromId,
forked_from_version=graph.forkedFromVersion,
user_id=graph.userId,
created_at=graph.createdAt,
)
class GraphModel(Graph, GraphMeta):
"""
Full graph model representing an existing graph from the database.
This is the primary model for working with persisted graphs. Includes all
graph data (nodes, links, sub_graphs) plus user ownership and timestamps.
Provides computed fields (input_schema, output_schema, etc.) used during
set-up (frontend) and execution (backend).
Inherits from:
- `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas
- `GraphMeta`: provides user_id, created_at for database records
"""
nodes: list[NodeModel] = Field(default_factory=list) # type: ignore
@property
def starting_nodes(self) -> list[NodeModel]:
outbound_nodes = {link.sink_id for link in self.links}
input_nodes = {
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
}
return [
node
for node in self.nodes
if node.id not in outbound_nodes or node.id in input_nodes
]
@property
def webhook_input_node(self) -> NodeModel | None: # type: ignore
return cast(NodeModel, super().webhook_input_node)
@computed_field @computed_field
@property @property
def credentials_input_schema(self) -> dict[str, Any]: def credentials_input_schema(self) -> dict[str, Any]:
graph_credentials_inputs = self.aggregate_credentials_inputs() schema = self._credentials_input_schema.jsonschema()
# Determine which credential fields are required based on credentials_optional metadata
graph_credentials_inputs = self.aggregate_credentials_inputs()
required_fields = []
# Build a map of node_id -> node for quick lookup
all_nodes = {node.id: node for node in self.nodes}
for sub_graph in self.sub_graphs:
for node in sub_graph.nodes:
all_nodes[node.id] = node
for field_key, (
_field_info,
node_field_pairs,
) in graph_credentials_inputs.items():
# A field is required if ANY node using it has credentials_optional=False
is_required = False
for node_id, _field_name in node_field_pairs:
node = all_nodes.get(node_id)
if node and not node.credentials_optional:
is_required = True
break
if is_required:
required_fields.append(field_key)
schema["required"] = required_fields
return schema
@property
def _credentials_input_schema(self) -> type[BlockSchema]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
logger.debug( logger.debug(
f"Combined credentials input fields for graph #{self.id} ({self.name}): " f"Combined credentials input fields for graph #{self.id} ({self.name}): "
f"{graph_credentials_inputs}" f"{graph_credentials_inputs}"
@@ -456,8 +406,8 @@ class GraphModel(Graph, GraphMeta):
# Warn if same-provider credentials inputs can't be combined (= bad UX) # Warn if same-provider credentials inputs can't be combined (= bad UX)
graph_cred_fields = list(graph_credentials_inputs.values()) graph_cred_fields = list(graph_credentials_inputs.values())
for i, (field, keys, _) in enumerate(graph_cred_fields): for i, (field, keys) in enumerate(graph_cred_fields):
for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]: for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
if field.provider != other_field.provider: if field.provider != other_field.provider:
continue continue
if ProviderName.HTTP in field.provider: if ProviderName.HTTP in field.provider:
@@ -473,78 +423,31 @@ class GraphModel(Graph, GraphMeta):
f"keys: {keys} <> {other_keys}." f"keys: {keys} <> {other_keys}."
) )
# Build JSON schema directly to avoid expensive create_model + validation overhead fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
properties = {} agg_field_key: (
required_fields = [] CredentialsMetaInput[
Literal[tuple(field_info.provider)], # type: ignore
for agg_field_key, ( Literal[tuple(field_info.supported_types)], # type: ignore
field_info, ],
_, CredentialsField(
is_required, required_scopes=set(field_info.required_scopes or []),
) in graph_credentials_inputs.items(): discriminator=field_info.discriminator,
providers = list(field_info.provider) discriminator_mapping=field_info.discriminator_mapping,
cred_types = list(field_info.supported_types) discriminator_values=field_info.discriminator_values,
),
field_schema: dict[str, Any] = {
"credentials_provider": providers,
"credentials_types": cred_types,
"type": "object",
"properties": {
"id": {"title": "Id", "type": "string"},
"title": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"default": None,
"title": "Title",
},
"provider": {
"title": "Provider",
"type": "string",
**(
{"enum": providers}
if len(providers) > 1
else {"const": providers[0]}
),
},
"type": {
"title": "Type",
"type": "string",
**(
{"enum": cred_types}
if len(cred_types) > 1
else {"const": cred_types[0]}
),
},
},
"required": ["id", "provider", "type"],
}
# Add other (optional) field info items
field_schema.update(
field_info.model_dump(
by_alias=True,
exclude_defaults=True,
exclude={"provider", "supported_types"}, # already included above
)
) )
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
# Ensure field schema is well-formed
CredentialsMetaInput.validate_credentials_field_schema(
field_schema, agg_field_key
)
properties[agg_field_key] = field_schema
if is_required:
required_fields.append(agg_field_key)
return {
"type": "object",
"properties": properties,
"required": required_fields,
} }
return create_model(
self.name.replace(" ", "") + "CredentialsInputSchema",
__base__=BlockSchema,
**fields, # type: ignore
)
def aggregate_credentials_inputs( def aggregate_credentials_inputs(
self, self,
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]: ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
""" """
Returns: Returns:
dict[aggregated_field_key, tuple( dict[aggregated_field_key, tuple(
@@ -552,19 +455,13 @@ class GraphModel(Graph, GraphMeta):
(now includes discriminator_values from matching nodes) (now includes discriminator_values from matching nodes)
set[(node_id, field_name)]: Node credentials fields that are set[(node_id, field_name)]: Node credentials fields that are
compatible with this aggregated field spec compatible with this aggregated field spec
bool: True if the field is required (any node has credentials_optional=False)
)] )]
""" """
# First collect all credential field data with input defaults # First collect all credential field data with input defaults
# Track (field_info, (node_id, field_name), is_required) for each credential field node_credential_data = []
node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
node_required_map: dict[str, bool] = {} # node_id -> is_required
for graph in [self] + self.sub_graphs: for graph in [self] + self.sub_graphs:
for node in graph.nodes: for node in graph.nodes:
# Track if this node requires credentials (credentials_optional=False means required)
node_required_map[node.id] = not node.credentials_optional
for ( for (
field_name, field_name,
field_info, field_info,
@@ -588,21 +485,37 @@ class GraphModel(Graph, GraphMeta):
) )
# Combine credential field info (this will merge discriminator_values automatically) # Combine credential field info (this will merge discriminator_values automatically)
combined = CredentialsFieldInfo.combine(*node_credential_data) return CredentialsFieldInfo.combine(*node_credential_data)
# Add is_required flag to each aggregated field
# A field is required if ANY node using it has credentials_optional=False class GraphModel(Graph):
return { user_id: str
key: ( nodes: list[NodeModel] = [] # type: ignore
field_info,
node_field_pairs, created_at: datetime
any(
node_required_map.get(node_id, True) @property
for node_id, _ in node_field_pairs def starting_nodes(self) -> list[NodeModel]:
), outbound_nodes = {link.sink_id for link in self.links}
) input_nodes = {
for key, (field_info, node_field_pairs) in combined.items() node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
} }
return [
node
for node in self.nodes
if node.id not in outbound_nodes or node.id in input_nodes
]
@property
def webhook_input_node(self) -> NodeModel | None: # type: ignore
return cast(NodeModel, super().webhook_input_node)
def meta(self) -> "GraphMeta":
"""
Returns a GraphMeta object with metadata about the graph.
This is used to return metadata about the graph without exposing nodes and links.
"""
return GraphMeta.from_graph(self)
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False): def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
""" """
@@ -886,14 +799,13 @@ class GraphModel(Graph, GraphMeta):
if is_static_output_block(link.source_id): if is_static_output_block(link.source_id):
link.is_static = True # Each value block output should be static. link.is_static = True # Each value block output should be static.
@classmethod @staticmethod
def from_db( # type: ignore[reportIncompatibleMethodOverride] def from_db(
cls,
graph: AgentGraph, graph: AgentGraph,
for_export: bool = False, for_export: bool = False,
sub_graphs: list[AgentGraph] | None = None, sub_graphs: list[AgentGraph] | None = None,
) -> Self: ) -> "GraphModel":
return cls( return GraphModel(
id=graph.id, id=graph.id,
user_id=graph.userId if not for_export else "", user_id=graph.userId if not for_export else "",
version=graph.version, version=graph.version,
@@ -919,28 +831,17 @@ class GraphModel(Graph, GraphMeta):
], ],
) )
def hide_nodes(self) -> "GraphModelWithoutNodes":
"""
Returns a copy of the `GraphModel` with nodes, links, and sub-graphs hidden
(excluded from serialization). They are still present in the model instance
so all computed fields (e.g. `credentials_input_schema`) still work.
"""
return GraphModelWithoutNodes.model_validate(self, from_attributes=True)
class GraphMeta(Graph):
user_id: str
class GraphModelWithoutNodes(GraphModel): # Easy work-around to prevent exposing nodes and links in the API response
""" nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
GraphModel variant that excludes nodes, links, and sub-graphs from serialization. links: list[Link] = Field(default=[], exclude=True)
Used in contexts like the store where exposing internal graph structure @staticmethod
is not desired. Inherits all computed fields from GraphModel but marks def from_graph(graph: GraphModel) -> "GraphMeta":
nodes and links as excluded from JSON output. return GraphMeta(**graph.model_dump())
"""
nodes: list[NodeModel] = Field(default_factory=list, exclude=True)
links: list[Link] = Field(default_factory=list, exclude=True)
sub_graphs: list[BaseGraph] = Field(default_factory=list, exclude=True)
class GraphsPaginated(BaseModel): class GraphsPaginated(BaseModel):
@@ -1011,11 +912,21 @@ async def list_graphs_paginated(
where=where_clause, where=where_clause,
distinct=["id"], distinct=["id"],
order={"version": "desc"}, order={"version": "desc"},
include=AGENT_GRAPH_INCLUDE,
skip=offset, skip=offset,
take=page_size, take=page_size,
) )
graph_models = [GraphMeta.from_db(graph) for graph in graphs] graph_models: list[GraphMeta] = []
for graph in graphs:
try:
graph_meta = GraphModel.from_db(graph).meta()
# Trigger serialization to validate that the graph is well formed
graph_meta.model_dump()
graph_models.append(graph_meta)
except Exception as e:
logger.error(f"Error processing graph {graph.id}: {e}")
continue
return GraphsPaginated( return GraphsPaginated(
graphs=graph_models, graphs=graph_models,

View File

@@ -19,6 +19,7 @@ from typing import (
cast, cast,
get_args, get_args,
) )
from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
from prisma.enums import CreditTransactionType, OnboardingStep from prisma.enums import CreditTransactionType, OnboardingStep
@@ -41,7 +42,6 @@ from typing_extensions import TypedDict
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads from backend.util.json import loads as json_loads
from backend.util.request import parse_url
from backend.util.settings import Secrets from backend.util.settings import Secrets
# Type alias for any provider name (including custom ones) # Type alias for any provider name (including custom ones)
@@ -163,6 +163,7 @@ class User(BaseModel):
if TYPE_CHECKING: if TYPE_CHECKING:
from prisma.models import User as PrismaUser from prisma.models import User as PrismaUser
from backend.data.block import BlockSchema
T = TypeVar("T") T = TypeVar("T")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -396,25 +397,19 @@ class HostScopedCredentials(_BaseCredentials):
def matches_url(self, url: str) -> bool: def matches_url(self, url: str) -> bool:
"""Check if this credential should be applied to the given URL.""" """Check if this credential should be applied to the given URL."""
request_host, request_port = _extract_host_from_url(url) parsed_url = urlparse(url)
cred_scope_host, cred_scope_port = _extract_host_from_url(self.host) # Extract hostname without port
request_host = parsed_url.hostname
if not request_host: if not request_host:
return False return False
# If a port is specified in credential host, the request host port must match # Simple host matching - exact match or wildcard subdomain match
if cred_scope_port is not None and request_port != cred_scope_port: if self.host == request_host:
return False
# Non-standard ports are only allowed if explicitly specified in credential host
elif cred_scope_port is None and request_port not in (80, 443, None):
return False
# Simple host matching
if cred_scope_host == request_host:
return True return True
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com") # Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
if cred_scope_host.startswith("*."): if self.host.startswith("*."):
domain = cred_scope_host[2:] # Remove "*." domain = self.host[2:] # Remove "*."
return request_host.endswith(f".{domain}") or request_host == domain return request_host.endswith(f".{domain}") or request_host == domain
return False return False
@@ -507,13 +502,15 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]: def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
return get_args(cls.model_fields["type"].annotation) return get_args(cls.model_fields["type"].annotation)
@staticmethod @classmethod
def validate_credentials_field_schema( def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
field_schema: dict[str, Any], field_name: str
):
"""Validates the schema of a credentials input field""" """Validates the schema of a credentials input field"""
field_name = next(
name for name, type in model.get_credentials_fields().items() if type is cls
)
field_schema = model.jsonschema()["properties"][field_name]
try: try:
field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema) schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
except ValidationError as e: except ValidationError as e:
if "Field required [type=missing" not in str(e): if "Field required [type=missing" not in str(e):
raise raise
@@ -523,11 +520,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
f"{field_schema}" f"{field_schema}"
) from e ) from e
providers = field_info.provider providers = cls.allowed_providers()
if ( if (
providers is not None providers is not None
and len(providers) > 1 and len(providers) > 1
and not field_info.discriminator and not schema_extra.discriminator
): ):
raise TypeError( raise TypeError(
f"Multi-provider CredentialsField '{field_name}' " f"Multi-provider CredentialsField '{field_name}' "
@@ -554,13 +551,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
) )
def _extract_host_from_url(url: str) -> tuple[str, int | None]: def _extract_host_from_url(url: str) -> str:
"""Extract host and port from URL for grouping host-scoped credentials.""" """Extract host from URL for grouping host-scoped credentials."""
try: try:
parsed = parse_url(url) parsed = urlparse(url)
return parsed.hostname or url, parsed.port return parsed.hostname or url
except Exception: except Exception:
return "", None return ""
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
@@ -609,7 +606,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
providers = frozenset( providers = frozenset(
[cast(CP, "http")] [cast(CP, "http")]
+ [ + [
cast(CP, parse_url(str(value)).netloc) cast(CP, _extract_host_from_url(str(value)))
for value in field.discriminator_values for value in field.discriminator_values
] ]
) )

View File

@@ -79,23 +79,10 @@ class TestHostScopedCredentials:
headers={"Authorization": SecretStr("Bearer token")}, headers={"Authorization": SecretStr("Bearer token")},
) )
# Non-standard ports require explicit port in credential host assert creds.matches_url("http://localhost:8080/api/v1")
assert not creds.matches_url("http://localhost:8080/api/v1")
assert creds.matches_url("https://localhost:443/secure/endpoint") assert creds.matches_url("https://localhost:443/secure/endpoint")
assert creds.matches_url("http://localhost/simple") assert creds.matches_url("http://localhost/simple")
def test_matches_url_with_explicit_port(self):
"""Test URL matching with explicit port in credential host."""
creds = HostScopedCredentials(
provider="custom",
host="localhost:8080",
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("http://localhost:8080/api/v1")
assert not creds.matches_url("http://localhost:3000/api/v1")
assert not creds.matches_url("http://localhost/simple")
def test_empty_headers_dict(self): def test_empty_headers_dict(self):
"""Test HostScopedCredentials with empty headers.""" """Test HostScopedCredentials with empty headers."""
creds = HostScopedCredentials( creds = HostScopedCredentials(
@@ -141,20 +128,8 @@ class TestHostScopedCredentials:
("*.example.com", "https://sub.api.example.com/test", True), ("*.example.com", "https://sub.api.example.com/test", True),
("*.example.com", "https://example.com/test", True), ("*.example.com", "https://example.com/test", True),
("*.example.com", "https://example.org/test", False), ("*.example.com", "https://example.org/test", False),
# Non-standard ports require explicit port in credential host ("localhost", "http://localhost:3000/test", True),
("localhost", "http://localhost:3000/test", False),
("localhost:3000", "http://localhost:3000/test", True),
("localhost", "http://127.0.0.1:3000/test", False), ("localhost", "http://127.0.0.1:3000/test", False),
# IPv6 addresses (frontend stores with brackets via URL.hostname)
("[::1]", "http://[::1]/test", True),
("[::1]", "http://[::1]:80/test", True),
("[::1]", "https://[::1]:443/test", True),
("[::1]", "http://[::1]:8080/test", False), # Non-standard port
("[::1]:8080", "http://[::1]:8080/test", True),
("[::1]:8080", "http://[::1]:9090/test", False),
("[2001:db8::1]", "http://[2001:db8::1]/path", True),
("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
], ],
) )
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool): def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):

View File

@@ -1,4 +1,3 @@
import asyncio
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
@@ -226,10 +225,6 @@ class SyncRabbitMQ(RabbitMQBase):
class AsyncRabbitMQ(RabbitMQBase): class AsyncRabbitMQ(RabbitMQBase):
"""Asynchronous RabbitMQ client""" """Asynchronous RabbitMQ client"""
def __init__(self, config: RabbitMQConfig):
super().__init__(config)
self._reconnect_lock: asyncio.Lock | None = None
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return bool(self._connection and not self._connection.is_closed) return bool(self._connection and not self._connection.is_closed)
@@ -240,17 +235,7 @@ class AsyncRabbitMQ(RabbitMQBase):
@conn_retry("AsyncRabbitMQ", "Acquiring async connection") @conn_retry("AsyncRabbitMQ", "Acquiring async connection")
async def connect(self): async def connect(self):
if self.is_connected and self._channel and not self._channel.is_closed: if self.is_connected:
return
if (
self.is_connected
and self._connection
and (self._channel is None or self._channel.is_closed)
):
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1)
await self.declare_infrastructure()
return return
self._connection = await aio_pika.connect_robust( self._connection = await aio_pika.connect_robust(
@@ -306,46 +291,24 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name exchange, routing_key=queue.routing_key or queue.name
) )
@property @func_retry
def _lock(self) -> asyncio.Lock: async def publish_message(
if self._reconnect_lock is None:
self._reconnect_lock = asyncio.Lock()
return self._reconnect_lock
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
"""Get a valid channel, reconnecting if the current one is stale.
Uses a lock to prevent concurrent reconnection attempts from racing.
"""
if self.is_ready:
return self._channel # type: ignore # is_ready guarantees non-None
async with self._lock:
# Double-check after acquiring lock
if self.is_ready:
return self._channel # type: ignore
self._channel = None
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
return self._channel
async def _publish_once(
self, self,
routing_key: str, routing_key: str,
message: str, message: str,
exchange: Optional[Exchange] = None, exchange: Optional[Exchange] = None,
persistent: bool = True, persistent: bool = True,
) -> None: ) -> None:
channel = await self._ensure_channel() if not self.is_ready:
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
if exchange: if exchange:
exchange_obj = await channel.get_exchange(exchange.name) exchange_obj = await self._channel.get_exchange(exchange.name)
else: else:
exchange_obj = channel.default_exchange exchange_obj = self._channel.default_exchange
await exchange_obj.publish( await exchange_obj.publish(
aio_pika.Message( aio_pika.Message(
@@ -359,23 +322,9 @@ class AsyncRabbitMQ(RabbitMQBase):
routing_key=routing_key, routing_key=routing_key,
) )
@func_retry
async def publish_message(
self,
routing_key: str,
message: str,
exchange: Optional[Exchange] = None,
persistent: bool = True,
) -> None:
try:
await self._publish_once(routing_key, message, exchange, persistent)
except aio_pika.exceptions.ChannelInvalidStateError:
logger.warning(
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
)
async with self._lock:
self._channel = None
await self._publish_once(routing_key, message, exchange, persistent)
async def get_channel(self) -> aio_pika.abc.AbstractChannel: async def get_channel(self) -> aio_pika.abc.AbstractChannel:
return await self._ensure_channel() if not self.is_ready:
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
return self._channel

View File

@@ -373,7 +373,7 @@ def make_node_credentials_input_map(
# Get aggregated credentials fields for the graph # Get aggregated credentials fields for the graph
graph_cred_inputs = graph.aggregate_credentials_inputs() graph_cred_inputs = graph.aggregate_credentials_inputs()
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items(): for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
# Best-effort map: skip missing items # Best-effort map: skip missing items
if graph_input_name not in graph_credentials_input: if graph_input_name not in graph_credentials_input:
continue continue

View File

@@ -224,14 +224,6 @@ openweathermap_credentials = APIKeyCredentials(
expires_at=None, expires_at=None,
) )
elevenlabs_credentials = APIKeyCredentials(
id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c",
provider="elevenlabs",
api_key=SecretStr(settings.secrets.elevenlabs_api_key),
title="Use Credits for ElevenLabs",
expires_at=None,
)
DEFAULT_CREDENTIALS = [ DEFAULT_CREDENTIALS = [
ollama_credentials, ollama_credentials,
revid_credentials, revid_credentials,
@@ -260,7 +252,6 @@ DEFAULT_CREDENTIALS = [
v0_credentials, v0_credentials,
webshare_proxy_credentials, webshare_proxy_credentials,
openweathermap_credentials, openweathermap_credentials,
elevenlabs_credentials,
] ]
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS} SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
@@ -375,8 +366,6 @@ class IntegrationCredentialsStore:
all_credentials.append(webshare_proxy_credentials) all_credentials.append(webshare_proxy_credentials)
if settings.secrets.openweathermap_api_key: if settings.secrets.openweathermap_api_key:
all_credentials.append(openweathermap_credentials) all_credentials.append(openweathermap_credentials)
if settings.secrets.elevenlabs_api_key:
all_credentials.append(elevenlabs_credentials)
return all_credentials return all_credentials
async def get_creds_by_id( async def get_creds_by_id(

View File

@@ -18,7 +18,6 @@ class ProviderName(str, Enum):
DISCORD = "discord" DISCORD = "discord"
D_ID = "d_id" D_ID = "d_id"
E2B = "e2b" E2B = "e2b"
ELEVENLABS = "elevenlabs"
FAL = "fal" FAL = "fal"
GITHUB = "github" GITHUB = "github"
GOOGLE = "google" GOOGLE = "google"

View File

@@ -8,8 +8,6 @@ from pathlib import Path
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal
from urllib.parse import urlparse from urllib.parse import urlparse
from pydantic import BaseModel
from backend.util.cloud_storage import get_cloud_storage_handler from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.request import Requests from backend.util.request import Requests
from backend.util.settings import Config from backend.util.settings import Config
@@ -19,35 +17,6 @@ from backend.util.virus_scanner import scan_content_safe
if TYPE_CHECKING: if TYPE_CHECKING:
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
class WorkspaceUri(BaseModel):
"""Parsed workspace:// URI."""
file_ref: str # File ID or path (e.g. "abc123" or "/path/to/file.txt")
mime_type: str | None = None # MIME type from fragment (e.g. "video/mp4")
is_path: bool = False # True if file_ref is a path (starts with "/")
def parse_workspace_uri(uri: str) -> WorkspaceUri:
"""Parse a workspace:// URI into its components.
Examples:
"workspace://abc123" → WorkspaceUri(file_ref="abc123", mime_type=None, is_path=False)
"workspace://abc123#video/mp4" → WorkspaceUri(file_ref="abc123", mime_type="video/mp4", is_path=False)
"workspace:///path/to/file.txt" → WorkspaceUri(file_ref="/path/to/file.txt", mime_type=None, is_path=True)
"""
raw = uri.removeprefix("workspace://")
mime_type: str | None = None
if "#" in raw:
raw, fragment = raw.split("#", 1)
mime_type = fragment or None
return WorkspaceUri(
file_ref=raw,
mime_type=mime_type,
is_path=raw.startswith("/"),
)
# Return format options for store_media_file # Return format options for store_media_file
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc. # - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs # - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
@@ -214,20 +183,22 @@ async def store_media_file(
"This file type is only available in CoPilot sessions." "This file type is only available in CoPilot sessions."
) )
# Parse workspace reference (strips #mimeType fragment from file ID) # Parse workspace reference
ws = parse_workspace_uri(file) # workspace://abc123 - by file ID
# workspace:///path/to/file.txt - by virtual path
file_ref = file[12:] # Remove "workspace://"
if ws.is_path: if file_ref.startswith("/"):
# Path reference: workspace:///path/to/file.txt # Path reference
workspace_content = await workspace_manager.read_file(ws.file_ref) workspace_content = await workspace_manager.read_file(file_ref)
file_info = await workspace_manager.get_file_info_by_path(ws.file_ref) file_info = await workspace_manager.get_file_info_by_path(file_ref)
filename = sanitize_filename( filename = sanitize_filename(
file_info.name if file_info else f"{uuid.uuid4()}.bin" file_info.name if file_info else f"{uuid.uuid4()}.bin"
) )
else: else:
# ID reference: workspace://abc123 or workspace://abc123#video/mp4 # ID reference
workspace_content = await workspace_manager.read_file_by_id(ws.file_ref) workspace_content = await workspace_manager.read_file_by_id(file_ref)
file_info = await workspace_manager.get_file_info(ws.file_ref) file_info = await workspace_manager.get_file_info(file_ref)
filename = sanitize_filename( filename = sanitize_filename(
file_info.name if file_info else f"{uuid.uuid4()}.bin" file_info.name if file_info else f"{uuid.uuid4()}.bin"
) )
@@ -342,14 +313,6 @@ async def store_media_file(
if not target_path.is_file(): if not target_path.is_file():
raise ValueError(f"Local file does not exist: {target_path}") raise ValueError(f"Local file does not exist: {target_path}")
# Virus scan the local file before any further processing
local_content = target_path.read_bytes()
if len(local_content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(local_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
await scan_content_safe(local_content, filename=sanitized_file)
# Return based on requested format # Return based on requested format
if return_format == "for_local_processing": if return_format == "for_local_processing":
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL # Use when processing files locally with tools like ffmpeg, MoviePy, PIL
@@ -371,21 +334,7 @@ async def store_media_file(
# Don't re-save if input was already from workspace # Don't re-save if input was already from workspace
if is_from_workspace: if is_from_workspace:
# Return original workspace reference, ensuring MIME type fragment # Return original workspace reference
ws = parse_workspace_uri(file)
if not ws.mime_type:
# Add MIME type fragment if missing (older refs without it)
try:
if ws.is_path:
info = await workspace_manager.get_file_info_by_path(
ws.file_ref
)
else:
info = await workspace_manager.get_file_info(ws.file_ref)
if info:
return MediaFileType(f"{file}#{info.mimeType}")
except Exception:
pass
return MediaFileType(file) return MediaFileType(file)
# Save new content to workspace # Save new content to workspace
@@ -397,7 +346,7 @@ async def store_media_file(
filename=filename, filename=filename,
overwrite=True, overwrite=True,
) )
return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}") return MediaFileType(f"workspace://{file_record.id}")
else: else:
raise ValueError(f"Invalid return_format: {return_format}") raise ValueError(f"Invalid return_format: {return_format}")

View File

@@ -247,100 +247,3 @@ class TestFileCloudIntegration:
execution_context=make_test_context(graph_exec_id=graph_exec_id), execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing", return_format="for_local_processing",
) )
@pytest.mark.asyncio
async def test_store_media_file_local_path_scanned(self):
"""Test that local file paths are scanned for viruses."""
graph_exec_id = "test-exec-123"
local_file = "test_video.mp4"
file_content = b"fake video content"
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.Path"
) as mock_path_class:
# Mock cloud storage handler - not a cloud path
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False
mock_handler_getter.return_value = mock_handler
# Mock virus scanner
mock_scan.return_value = None
# Mock file system operations
mock_base_path = MagicMock()
mock_target_path = MagicMock()
mock_resolved_path = MagicMock()
mock_path_class.return_value = mock_base_path
mock_base_path.mkdir = MagicMock()
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
mock_target_path.resolve.return_value = mock_resolved_path
mock_resolved_path.is_relative_to.return_value = True
mock_resolved_path.is_file.return_value = True
mock_resolved_path.read_bytes.return_value = file_content
mock_resolved_path.relative_to.return_value = Path(local_file)
mock_resolved_path.name = local_file
result = await store_media_file(
file=MediaFileType(local_file),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# Verify virus scan was called for local file
mock_scan.assert_called_once_with(file_content, filename=local_file)
# Result should be the relative path
assert str(result) == local_file
@pytest.mark.asyncio
async def test_store_media_file_local_path_virus_detected(self):
"""Test that infected local files raise VirusDetectedError."""
from backend.api.features.store.exceptions import VirusDetectedError
graph_exec_id = "test-exec-123"
local_file = "infected.exe"
file_content = b"malicious content"
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.Path"
) as mock_path_class:
# Mock cloud storage handler - not a cloud path
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False
mock_handler_getter.return_value = mock_handler
# Mock virus scanner to detect virus
mock_scan.side_effect = VirusDetectedError(
"EICAR-Test-File", "File rejected due to virus detection"
)
# Mock file system operations
mock_base_path = MagicMock()
mock_target_path = MagicMock()
mock_resolved_path = MagicMock()
mock_path_class.return_value = mock_base_path
mock_base_path.mkdir = MagicMock()
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
mock_target_path.resolve.return_value = mock_resolved_path
mock_resolved_path.is_relative_to.return_value = True
mock_resolved_path.is_file.return_value = True
mock_resolved_path.read_bytes.return_value = file_content
with pytest.raises(VirusDetectedError):
await store_media_file(
file=MediaFileType(local_file),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)

View File

@@ -157,7 +157,12 @@ async def validate_url(
is_trusted: Boolean indicating if the hostname is in trusted_origins is_trusted: Boolean indicating if the hostname is in trusted_origins
ip_addresses: List of IP addresses for the host; empty if the host is trusted ip_addresses: List of IP addresses for the host; empty if the host is trusted
""" """
parsed = parse_url(url) # Canonicalize URL
url = url.strip("/ ").replace("\\", "/")
parsed = urlparse(url)
if not parsed.scheme:
url = f"http://{url}"
parsed = urlparse(url)
# Check scheme # Check scheme
if parsed.scheme not in ALLOWED_SCHEMES: if parsed.scheme not in ALLOWED_SCHEMES:
@@ -215,17 +220,6 @@ async def validate_url(
) )
def parse_url(url: str) -> URL:
"""Canonicalizes and parses a URL string."""
url = url.strip("/ ").replace("\\", "/")
# Ensure scheme is present for proper parsing
if not re.match(r"[a-z0-9+.\-]+://", url):
url = f"http://{url}"
return urlparse(url)
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL: def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
""" """
Pins a URL to a specific IP address to prevent DNS rebinding attacks. Pins a URL to a specific IP address to prevent DNS rebinding attacks.

View File

@@ -656,7 +656,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
e2b_api_key: str = Field(default="", description="E2B API key") e2b_api_key: str = Field(default="", description="E2B API key")
nvidia_api_key: str = Field(default="", description="Nvidia API key") nvidia_api_key: str = Field(default="", description="Nvidia API key")
mem0_api_key: str = Field(default="", description="Mem0 API key") mem0_api_key: str = Field(default="", description="Mem0 API key")
elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
linear_client_id: str = Field(default="", description="Linear client ID") linear_client_id: str = Field(default="", description="Linear client ID")
linear_client_secret: str = Field(default="", description="Linear client secret") linear_client_secret: str = Field(default="", description="Linear client secret")

View File

@@ -22,7 +22,6 @@ from backend.data.workspace import (
soft_delete_workspace_file, soft_delete_workspace_file,
) )
from backend.util.settings import Config from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -188,9 +187,6 @@ class WorkspaceManager:
f"{Config().max_file_size_mb}MB limit" f"{Config().max_file_size_mb}MB limit"
) )
# Virus scan content before persisting (defense in depth)
await scan_content_safe(content, filename=filename)
# Determine path with session scoping # Determine path with session scoping
if path is None: if path is None:
path = f"/{filename}" path = f"/{filename}"

File diff suppressed because it is too large Load Diff

View File

@@ -12,16 +12,15 @@ python = ">=3.10,<3.14"
aio-pika = "^9.5.5" aio-pika = "^9.5.5"
aiohttp = "^3.10.0" aiohttp = "^3.10.0"
aiodns = "^3.5.0" aiodns = "^3.5.0"
anthropic = "^0.79.0" anthropic = "^0.59.0"
apscheduler = "^3.11.1" apscheduler = "^3.11.1"
autogpt-libs = { path = "../autogpt_libs", develop = true } autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" } bleach = { extras = ["css"], version = "^6.2.0" }
click = "^8.2.0" click = "^8.2.0"
cryptography = "^46.0" cryptography = "^45.0"
discord-py = "^2.5.2" discord-py = "^2.5.2"
e2b-code-interpreter = "^1.5.2" e2b-code-interpreter = "^1.5.2"
elevenlabs = "^1.50.0" fastapi = "^0.116.1"
fastapi = "^0.128.6"
feedparser = "^6.0.11" feedparser = "^6.0.11"
flake8 = "^7.3.0" flake8 = "^7.3.0"
google-api-python-client = "^2.177.0" google-api-python-client = "^2.177.0"
@@ -34,11 +33,11 @@ html2text = "^2024.2.26"
jinja2 = "^3.1.6" jinja2 = "^3.1.6"
jsonref = "^1.1.0" jsonref = "^1.1.0"
jsonschema = "^4.25.0" jsonschema = "^4.25.0"
langfuse = "^3.14.1" langfuse = "^3.11.0"
launchdarkly-server-sdk = "^9.14.1" launchdarkly-server-sdk = "^9.12.0"
mem0ai = "^0.1.115" mem0ai = "^0.1.115"
moviepy = "^2.1.2" moviepy = "^2.1.2"
ollama = "^0.6.1" ollama = "^0.5.1"
openai = "^1.97.1" openai = "^1.97.1"
orjson = "^3.10.0" orjson = "^3.10.0"
pika = "^1.3.2" pika = "^1.3.2"
@@ -48,16 +47,16 @@ postmarker = "^1.0"
praw = "~7.8.1" praw = "~7.8.1"
prisma = "^0.15.0" prisma = "^0.15.0"
rank-bm25 = "^0.2.2" rank-bm25 = "^0.2.2"
prometheus-client = "^0.24.1" prometheus-client = "^0.22.1"
prometheus-fastapi-instrumentator = "^7.0.0" prometheus-fastapi-instrumentator = "^7.0.0"
psutil = "^7.0.0" psutil = "^7.0.0"
psycopg2-binary = "^2.9.10" psycopg2-binary = "^2.9.10"
pydantic = { extras = ["email"], version = "^2.12.5" } pydantic = { extras = ["email"], version = "^2.11.7" }
pydantic-settings = "^2.12.0" pydantic-settings = "^2.10.1"
pytest = "^8.4.1" pytest = "^8.4.1"
pytest-asyncio = "^1.1.0" pytest-asyncio = "^1.1.0"
python-dotenv = "^1.1.1" python-dotenv = "^1.1.1"
python-multipart = "^0.0.22" python-multipart = "^0.0.20"
redis = "^6.2.0" redis = "^6.2.0"
regex = "^2025.9.18" regex = "^2025.9.18"
replicate = "^1.0.6" replicate = "^1.0.6"
@@ -65,19 +64,18 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
sqlalchemy = "^2.0.40" sqlalchemy = "^2.0.40"
strenum = "^0.4.9" strenum = "^0.4.9"
stripe = "^11.5.0" stripe = "^11.5.0"
supabase = "2.27.3" supabase = "2.17.0"
tenacity = "^9.1.4" tenacity = "^9.1.2"
todoist-api-python = "^2.1.7" todoist-api-python = "^2.1.7"
tweepy = "^4.16.0" tweepy = "^4.16.0"
uvicorn = { extras = ["standard"], version = "^0.40.0" } uvicorn = { extras = ["standard"], version = "^0.35.0" }
websockets = "^15.0" websockets = "^15.0"
youtube-transcript-api = "^1.2.1" youtube-transcript-api = "^1.2.1"
yt-dlp = "2025.12.08"
zerobouncesdk = "^1.1.2" zerobouncesdk = "^1.1.2"
# NOTE: please insert new dependencies in their alphabetical location # NOTE: please insert new dependencies in their alphabetical location
pytest-snapshot = "^0.9.0" pytest-snapshot = "^0.9.0"
aiofiles = "^24.1.0" aiofiles = "^24.1.0"
tiktoken = "^0.12.0" tiktoken = "^0.9.0"
aioclamd = "^1.0.0" aioclamd = "^1.0.0"
setuptools = "^80.9.0" setuptools = "^80.9.0"
gcloud-aio-storage = "^9.5.0" gcloud-aio-storage = "^9.5.0"
@@ -95,13 +93,13 @@ black = "^24.10.0"
faker = "^38.2.0" faker = "^38.2.0"
httpx = "^0.28.1" httpx = "^0.28.1"
isort = "^5.13.2" isort = "^5.13.2"
poethepoet = "^0.41.0" poethepoet = "^0.37.0"
pre-commit = "^4.4.0" pre-commit = "^4.4.0"
pyright = "^1.1.407" pyright = "^1.1.407"
pytest-mock = "^3.15.1" pytest-mock = "^3.15.1"
pytest-watcher = "^0.6.3" pytest-watcher = "^0.4.2"
requests = "^2.32.5" requests = "^2.32.5"
ruff = "^0.15.0" ruff = "^0.14.5"
# NOTE: please insert new dependencies in their alphabetical location # NOTE: please insert new dependencies in their alphabetical location
[build-system] [build-system]

View File

@@ -3,6 +3,7 @@
"credentials_input_schema": { "credentials_input_schema": {
"properties": {}, "properties": {},
"required": [], "required": [],
"title": "TestGraphCredentialsInputSchema",
"type": "object" "type": "object"
}, },
"description": "A test graph", "description": "A test graph",

View File

@@ -1,14 +1,34 @@
[ [
{ {
"created_at": "2025-09-04T13:37:00", "credentials_input_schema": {
"properties": {},
"required": [],
"title": "TestGraphCredentialsInputSchema",
"type": "object"
},
"description": "A test graph", "description": "A test graph",
"forked_from_id": null, "forked_from_id": null,
"forked_from_version": null, "forked_from_version": null,
"has_external_trigger": false,
"has_human_in_the_loop": false,
"has_sensitive_action": false,
"id": "graph-123", "id": "graph-123",
"input_schema": {
"properties": {},
"required": [],
"type": "object"
},
"instructions": null, "instructions": null,
"is_active": true, "is_active": true,
"name": "Test Graph", "name": "Test Graph",
"output_schema": {
"properties": {},
"required": [],
"type": "object"
},
"recommended_schedule_cron": null, "recommended_schedule_cron": null,
"sub_graphs": [],
"trigger_setup_info": null,
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"version": 1 "version": 1
} }

View File

@@ -25,12 +25,8 @@ RUN if [ -f .env.production ]; then \
cp .env.default .env; \ cp .env.default .env; \
fi fi
RUN pnpm run generate:api RUN pnpm run generate:api
# Disable source-map generation in Docker builds to halve webpack memory usage.
# Source maps are only useful when SENTRY_AUTH_TOKEN is set (Vercel deploys);
# the Docker image never uploads them, so generating them just wastes RAM.
ENV NEXT_PUBLIC_SOURCEMAPS="false"
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it # In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=8192" pnpm build; else NODE_OPTIONS="--max-old-space-size=8192" pnpm build; fi RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=4096" pnpm build; else NODE_OPTIONS="--max-old-space-size=4096" pnpm build; fi
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile # Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
FROM node:21-alpine AS prod FROM node:21-alpine AS prod

View File

@@ -1,12 +1,8 @@
import { withSentryConfig } from "@sentry/nextjs"; import { withSentryConfig } from "@sentry/nextjs";
// Allow Docker builds to skip source-map generation (halves memory usage).
// Defaults to true so Vercel/local builds are unaffected.
const enableSourceMaps = process.env.NEXT_PUBLIC_SOURCEMAPS !== "false";
/** @type {import('next').NextConfig} */ /** @type {import('next').NextConfig} */
const nextConfig = { const nextConfig = {
productionBrowserSourceMaps: enableSourceMaps, productionBrowserSourceMaps: true,
// Externalize OpenTelemetry packages to fix Turbopack HMR issues // Externalize OpenTelemetry packages to fix Turbopack HMR issues
serverExternalPackages: [ serverExternalPackages: [
"@opentelemetry/instrumentation", "@opentelemetry/instrumentation",
@@ -18,37 +14,9 @@ const nextConfig = {
serverActions: { serverActions: {
bodySizeLimit: "256mb", bodySizeLimit: "256mb",
}, },
// Increase body size limit for API routes (file uploads) - 256MB to match backend limit
proxyClientMaxBodySize: "256mb",
middlewareClientMaxBodySize: "256mb", middlewareClientMaxBodySize: "256mb",
// Limit parallel webpack workers to reduce peak memory during builds.
cpus: 2,
},
// Work around cssnano "Invalid array length" bug in Next.js's bundled
// cssnano-simple comment parser when processing very large CSS chunks.
// CSS is still bundled correctly; gzip handles most of the size savings anyway.
webpack: (config, { dev }) => {
if (!dev) {
// Next.js adds CssMinimizerPlugin internally (after user config), so we
// can't filter it from config.plugins. Instead, intercept the webpack
// compilation hooks and replace the buggy plugin's tap with a no-op.
config.plugins.push({
apply(compiler) {
compiler.hooks.compilation.tap(
"DisableCssMinimizer",
(compilation) => {
compilation.hooks.processAssets.intercept({
register: (tap) => {
if (tap.name === "CssMinimizerPlugin") {
return { ...tap, fn: async () => {} };
}
return tap;
},
});
},
);
},
});
}
return config;
}, },
images: { images: {
domains: [ domains: [
@@ -86,16 +54,9 @@ const nextConfig = {
transpilePackages: ["geist"], transpilePackages: ["geist"],
}; };
// Only run the Sentry webpack plugin when we can actually upload source maps const isDevelopmentBuild = process.env.NODE_ENV !== "production";
// (i.e. on Vercel with SENTRY_AUTH_TOKEN set). The Sentry *runtime* SDK
// (imported in app code) still captures errors without the plugin.
// Skipping the plugin saves ~1 GB of peak memory during `next build`.
const skipSentryPlugin =
process.env.NODE_ENV !== "production" ||
!enableSourceMaps ||
!process.env.SENTRY_AUTH_TOKEN;
export default skipSentryPlugin export default isDevelopmentBuild
? nextConfig ? nextConfig
: withSentryConfig(nextConfig, { : withSentryConfig(nextConfig, {
// For all available options, see: // For all available options, see:
@@ -135,7 +96,7 @@ export default skipSentryPlugin
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/ // This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
sourcemaps: { sourcemaps: {
disable: !enableSourceMaps, disable: false,
assets: [".next/**/*.js", ".next/**/*.js.map"], assets: [".next/**/*.js", ".next/**/*.js.map"],
ignore: ["**/node_modules/**"], ignore: ["**/node_modules/**"],
deleteSourcemapsAfterUpload: false, // Source is public anyway :) deleteSourcemapsAfterUpload: false, // Source is public anyway :)

View File

@@ -7,7 +7,7 @@
}, },
"scripts": { "scripts": {
"dev": "pnpm run generate:api:force && next dev --turbo", "dev": "pnpm run generate:api:force && next dev --turbo",
"build": "cross-env NODE_OPTIONS=--max-old-space-size=16384 next build", "build": "next build",
"start": "next start", "start": "next start",
"start:standalone": "cd .next/standalone && node server.js", "start:standalone": "cd .next/standalone && node server.js",
"lint": "next lint && prettier --check .", "lint": "next lint && prettier --check .",
@@ -30,7 +30,6 @@
"defaults" "defaults"
], ],
"dependencies": { "dependencies": {
"@ai-sdk/react": "3.0.61",
"@faker-js/faker": "10.0.0", "@faker-js/faker": "10.0.0",
"@hookform/resolvers": "5.2.2", "@hookform/resolvers": "5.2.2",
"@next/third-parties": "15.4.6", "@next/third-parties": "15.4.6",
@@ -61,10 +60,6 @@
"@rjsf/utils": "6.1.2", "@rjsf/utils": "6.1.2",
"@rjsf/validator-ajv8": "6.1.2", "@rjsf/validator-ajv8": "6.1.2",
"@sentry/nextjs": "10.27.0", "@sentry/nextjs": "10.27.0",
"@streamdown/cjk": "1.0.1",
"@streamdown/code": "1.0.1",
"@streamdown/math": "1.0.1",
"@streamdown/mermaid": "1.0.1",
"@supabase/ssr": "0.7.0", "@supabase/ssr": "0.7.0",
"@supabase/supabase-js": "2.78.0", "@supabase/supabase-js": "2.78.0",
"@tanstack/react-query": "5.90.6", "@tanstack/react-query": "5.90.6",
@@ -73,7 +68,6 @@
"@vercel/analytics": "1.5.0", "@vercel/analytics": "1.5.0",
"@vercel/speed-insights": "1.2.0", "@vercel/speed-insights": "1.2.0",
"@xyflow/react": "12.9.2", "@xyflow/react": "12.9.2",
"ai": "6.0.59",
"boring-avatars": "1.11.2", "boring-avatars": "1.11.2",
"class-variance-authority": "0.7.1", "class-variance-authority": "0.7.1",
"clsx": "2.1.1", "clsx": "2.1.1",
@@ -93,6 +87,7 @@
"launchdarkly-react-client-sdk": "3.9.0", "launchdarkly-react-client-sdk": "3.9.0",
"lodash": "4.17.21", "lodash": "4.17.21",
"lucide-react": "0.552.0", "lucide-react": "0.552.0",
"moment": "2.30.1",
"next": "15.4.10", "next": "15.4.10",
"next-themes": "0.4.6", "next-themes": "0.4.6",
"nuqs": "2.7.2", "nuqs": "2.7.2",
@@ -107,7 +102,7 @@
"react-markdown": "9.0.3", "react-markdown": "9.0.3",
"react-modal": "3.16.3", "react-modal": "3.16.3",
"react-shepherd": "6.1.9", "react-shepherd": "6.1.9",
"react-window": "2.2.0", "react-window": "1.8.11",
"recharts": "3.3.0", "recharts": "3.3.0",
"rehype-autolink-headings": "7.1.0", "rehype-autolink-headings": "7.1.0",
"rehype-highlight": "7.0.2", "rehype-highlight": "7.0.2",
@@ -117,11 +112,9 @@
"remark-math": "6.0.0", "remark-math": "6.0.0",
"shepherd.js": "14.5.1", "shepherd.js": "14.5.1",
"sonner": "2.0.7", "sonner": "2.0.7",
"streamdown": "2.1.0",
"tailwind-merge": "2.6.0", "tailwind-merge": "2.6.0",
"tailwind-scrollbar": "3.1.0", "tailwind-scrollbar": "3.1.0",
"tailwindcss-animate": "1.0.7", "tailwindcss-animate": "1.0.7",
"use-stick-to-bottom": "1.1.2",
"uuid": "11.1.0", "uuid": "11.1.0",
"vaul": "1.1.2", "vaul": "1.1.2",
"zod": "3.25.76", "zod": "3.25.76",
@@ -147,7 +140,7 @@
"@types/react": "18.3.17", "@types/react": "18.3.17",
"@types/react-dom": "18.3.5", "@types/react-dom": "18.3.5",
"@types/react-modal": "3.16.3", "@types/react-modal": "3.16.3",
"@types/react-window": "2.0.0", "@types/react-window": "1.8.8",
"@vitejs/plugin-react": "5.1.2", "@vitejs/plugin-react": "5.1.2",
"axe-playwright": "2.2.2", "axe-playwright": "2.2.2",
"chromatic": "13.3.3", "chromatic": "13.3.3",
@@ -179,8 +172,7 @@
}, },
"pnpm": { "pnpm": {
"overrides": { "overrides": {
"@opentelemetry/instrumentation": "0.209.0", "@opentelemetry/instrumentation": "0.209.0"
"lodash-es": "4.17.23"
} }
}, },
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd" "packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput"; import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphModel } from "@/app/api/__generated__/models/graphModel"; import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput"; import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
import { useState } from "react"; import { useState } from "react";
import { getSchemaDefaultCredentials } from "../../helpers"; import { getSchemaDefaultCredentials } from "../../helpers";
@@ -9,7 +9,7 @@ type Credential = CredentialsMetaInput | undefined;
type Credentials = Record<string, Credential>; type Credentials = Record<string, Credential>;
type Props = { type Props = {
agent: GraphModel | null; agent: GraphMeta | null;
siblingInputs?: Record<string, any>; siblingInputs?: Record<string, any>;
onCredentialsChange: ( onCredentialsChange: (
credentials: Record<string, CredentialsMetaInput>, credentials: Record<string, CredentialsMetaInput>,

View File

@@ -1,9 +1,9 @@
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput"; import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphModel } from "@/app/api/__generated__/models/graphModel"; import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types"; import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types";
export function getCredentialFields( export function getCredentialFields(
agent: GraphModel | null, agent: GraphMeta | null,
): AgentCredentialsFields { ): AgentCredentialsFields {
if (!agent) return {}; if (!agent) return {};

View File

@@ -3,10 +3,10 @@ import type {
CredentialsMetaInput, CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types"; } from "@/lib/autogpt-server-api/types";
import type { InputValues } from "./types"; import type { InputValues } from "./types";
import { GraphModel } from "@/app/api/__generated__/models/graphModel"; import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
export function computeInitialAgentInputs( export function computeInitialAgentInputs(
agent: GraphModel | null, agent: GraphMeta | null,
existingInputs?: InputValues | null, existingInputs?: InputValues | null,
): InputValues { ): InputValues {
const properties = agent?.input_schema?.properties || {}; const properties = agent?.input_schema?.properties || {};
@@ -29,7 +29,7 @@ export function computeInitialAgentInputs(
} }
type IsRunDisabledParams = { type IsRunDisabledParams = {
agent: GraphModel | null; agent: GraphMeta | null;
isRunning: boolean; isRunning: boolean;
agentInputs: InputValues | null | undefined; agentInputs: InputValues | null | undefined;
}; };

View File

@@ -1,17 +1,6 @@
import { OAuthPopupResultMessage } from "./types"; import { OAuthPopupResultMessage } from "./types";
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
/**
* Safely encode a value as JSON for embedding in a script tag.
* Escapes characters that could break out of the script context to prevent XSS.
*/
function safeJsonStringify(value: unknown): string {
return JSON.stringify(value)
.replace(/</g, "\\u003c")
.replace(/>/g, "\\u003e")
.replace(/&/g, "\\u0026");
}
// This route is intended to be used as the callback for integration OAuth flows, // This route is intended to be used as the callback for integration OAuth flows,
// controlled by the CredentialsInput component. The CredentialsInput opens the login // controlled by the CredentialsInput component. The CredentialsInput opens the login
// page in a pop-up window, which then redirects to this route to close the loop. // page in a pop-up window, which then redirects to this route to close the loop.
@@ -34,13 +23,12 @@ export async function GET(request: Request) {
console.debug("Sending message to opener:", message); console.debug("Sending message to opener:", message);
// Return a response with the message as JSON and a script to close the window // Return a response with the message as JSON and a script to close the window
// Use safeJsonStringify to prevent XSS by escaping <, >, and & characters
return new NextResponse( return new NextResponse(
` `
<html> <html>
<body> <body>
<script> <script>
window.opener.postMessage(${safeJsonStringify(message)}); window.opener.postMessage(${JSON.stringify(message)});
window.close(); window.close();
</script> </script>
</body> </body>

View File

@@ -1,4 +1,4 @@
import debounce from "lodash/debounce"; import { debounce } from "lodash";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { useBlockMenuStore } from "../../../../stores/blockMenuStore"; import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
import { getQueryClient } from "@/lib/react-query/queryClient"; import { getQueryClient } from "@/lib/react-query/queryClient";

View File

@@ -70,10 +70,10 @@ export const HorizontalScroll: React.FC<HorizontalScrollAreaProps> = ({
{children} {children}
</div> </div>
{canScrollLeft && ( {canScrollLeft && (
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-background via-background/80 to-background/0" /> <div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-white via-white/80 to-white/0" />
)} )}
{canScrollRight && ( {canScrollRight && (
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-background via-background/80 to-background/0" /> <div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-white via-white/80 to-white/0" />
)} )}
{canScrollLeft && ( {canScrollLeft && (
<button <button

View File

@@ -30,8 +30,6 @@ import {
} from "@/components/atoms/Tooltip/BaseTooltip"; } from "@/components/atoms/Tooltip/BaseTooltip";
import { GraphMeta } from "@/lib/autogpt-server-api"; import { GraphMeta } from "@/lib/autogpt-server-api";
import jaro from "jaro-winkler"; import jaro from "jaro-winkler";
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { okData } from "@/app/api/helpers";
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & { type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
uiKey?: string; uiKey?: string;
@@ -109,8 +107,6 @@ export function BlocksControl({
.filter((b) => b.uiType !== BlockUIType.AGENT) .filter((b) => b.uiType !== BlockUIType.AGENT)
.sort((a, b) => a.name.localeCompare(b.name)); .sort((a, b) => a.name.localeCompare(b.name));
// Agent blocks are created from GraphMeta which doesn't include schemas.
// Schemas will be fetched on-demand when the block is actually added.
const agentBlockList = flows const agentBlockList = flows
.map((flow): _Block => { .map((flow): _Block => {
return { return {
@@ -120,9 +116,8 @@ export function BlocksControl({
`Ver.${flow.version}` + `Ver.${flow.version}` +
(flow.description ? ` | ${flow.description}` : ""), (flow.description ? ` | ${flow.description}` : ""),
categories: [{ category: "AGENT", description: "" }], categories: [{ category: "AGENT", description: "" }],
// Empty schemas - will be populated when block is added inputSchema: flow.input_schema,
inputSchema: { type: "object", properties: {} }, outputSchema: flow.output_schema,
outputSchema: { type: "object", properties: {} },
staticOutput: false, staticOutput: false,
uiType: BlockUIType.AGENT, uiType: BlockUIType.AGENT,
costs: [], costs: [],
@@ -130,7 +125,8 @@ export function BlocksControl({
hardcodedValues: { hardcodedValues: {
graph_id: flow.id, graph_id: flow.id,
graph_version: flow.version, graph_version: flow.version,
// Schemas will be fetched on-demand when block is added input_schema: flow.input_schema,
output_schema: flow.output_schema,
}, },
}; };
}) })
@@ -186,37 +182,6 @@ export function BlocksControl({
setSelectedCategory(null); setSelectedCategory(null);
}, []); }, []);
// Handler to add a block, fetching graph data on-demand for agent blocks
const handleAddBlock = useCallback(
async (block: _Block & { notAvailable: string | null }) => {
if (block.notAvailable) return;
// For agent blocks, fetch the full graph to get schemas
if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) {
const graphID = block.hardcodedValues.graph_id as string;
const graphVersion = block.hardcodedValues.graph_version as number;
const graphData = okData(
await getV1GetSpecificGraph(graphID, { version: graphVersion }),
);
if (graphData) {
addBlock(block.id, block.name, {
...block.hardcodedValues,
input_schema: graphData.input_schema,
output_schema: graphData.output_schema,
});
} else {
// Fallback: add without schemas (will be incomplete)
console.error("Failed to fetch graph data for agent block");
addBlock(block.id, block.name, block.hardcodedValues || {});
}
} else {
addBlock(block.id, block.name, block.hardcodedValues || {});
}
},
[addBlock],
);
// Extract unique categories from blocks // Extract unique categories from blocks
const categories = useMemo(() => { const categories = useMemo(() => {
return Array.from( return Array.from(
@@ -338,7 +303,10 @@ export function BlocksControl({
}), }),
); );
}} }}
onClick={() => handleAddBlock(block)} onClick={() =>
!block.notAvailable &&
addBlock(block.id, block.name, block?.hardcodedValues || {})
}
title={block.notAvailable ?? undefined} title={block.notAvailable ?? undefined}
> >
<div <div

View File

@@ -1,6 +1,6 @@
import { beautifyString } from "@/lib/utils"; import { beautifyString } from "@/lib/utils";
import { Clipboard, Maximize2 } from "lucide-react"; import { Clipboard, Maximize2 } from "lucide-react";
import React, { useMemo, useState } from "react"; import React, { useState } from "react";
import { Button } from "../../../../../components/__legacy__/ui/button"; import { Button } from "../../../../../components/__legacy__/ui/button";
import { ContentRenderer } from "../../../../../components/__legacy__/ui/render"; import { ContentRenderer } from "../../../../../components/__legacy__/ui/render";
import { import {
@@ -11,12 +11,6 @@ import {
TableHeader, TableHeader,
TableRow, TableRow,
} from "../../../../../components/__legacy__/ui/table"; } from "../../../../../components/__legacy__/ui/table";
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
import {
globalRegistry,
OutputItem,
} from "@/components/contextual/OutputRenderers";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useToast } from "../../../../../components/molecules/Toast/use-toast"; import { useToast } from "../../../../../components/molecules/Toast/use-toast";
import ExpandableOutputDialog from "./ExpandableOutputDialog"; import ExpandableOutputDialog from "./ExpandableOutputDialog";
@@ -32,9 +26,6 @@ export default function DataTable({
data, data,
}: DataTableProps) { }: DataTableProps) {
const { toast } = useToast(); const { toast } = useToast();
const enableEnhancedOutputHandling = useGetFlag(
Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
);
const [expandedDialog, setExpandedDialog] = useState<{ const [expandedDialog, setExpandedDialog] = useState<{
isOpen: boolean; isOpen: boolean;
execId: string; execId: string;
@@ -42,15 +33,6 @@ export default function DataTable({
data: any[]; data: any[];
} | null>(null); } | null>(null);
// Prepare renderers for each item when enhanced mode is enabled
const getItemRenderer = useMemo(() => {
if (!enableEnhancedOutputHandling) return null;
return (item: unknown) => {
const metadata: OutputMetadata = {};
return globalRegistry.getRenderer(item, metadata);
};
}, [enableEnhancedOutputHandling]);
const copyData = (pin: string, data: string) => { const copyData = (pin: string, data: string) => {
navigator.clipboard.writeText(data).then(() => { navigator.clipboard.writeText(data).then(() => {
toast({ toast({
@@ -120,31 +102,15 @@ export default function DataTable({
<Clipboard size={18} /> <Clipboard size={18} />
</Button> </Button>
</div> </div>
{value.map((item, index) => { {value.map((item, index) => (
const renderer = getItemRenderer?.(item); <React.Fragment key={index}>
if (enableEnhancedOutputHandling && renderer) { <ContentRenderer
const metadata: OutputMetadata = {}; value={item}
return ( truncateLongData={truncateLongData}
<React.Fragment key={index}> />
<OutputItem {index < value.length - 1 && ", "}
value={item} </React.Fragment>
metadata={metadata} ))}
renderer={renderer}
/>
{index < value.length - 1 && ", "}
</React.Fragment>
);
}
return (
<React.Fragment key={index}>
<ContentRenderer
value={item}
truncateLongData={truncateLongData}
/>
{index < value.length - 1 && ", "}
</React.Fragment>
);
})}
</div> </div>
</TableCell> </TableCell>
</TableRow> </TableRow>

View File

@@ -29,17 +29,13 @@ import "@xyflow/react/dist/style.css";
import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode"; import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode";
import "./flow.css"; import "./flow.css";
import { import {
BlockIORootSchema,
BlockUIType, BlockUIType,
formatEdgeID, formatEdgeID,
GraphExecutionID, GraphExecutionID,
GraphID, GraphID,
GraphMeta, GraphMeta,
LibraryAgent, LibraryAgent,
SpecialBlockID,
} from "@/lib/autogpt-server-api"; } from "@/lib/autogpt-server-api";
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { okData } from "@/app/api/helpers";
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types"; import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
import { Key, storage } from "@/services/storage/local-storage"; import { Key, storage } from "@/services/storage/local-storage";
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils"; import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
@@ -691,94 +687,8 @@ const FlowEditor: React.FC<{
[getNode, updateNode, nodes], [getNode, updateNode, nodes],
); );
/* Shared helper to create and add a node */
const createAndAddNode = useCallback(
async (
blockID: string,
blockName: string,
hardcodedValues: Record<string, any>,
position: { x: number; y: number },
): Promise<CustomNode | null> => {
const nodeSchema = availableBlocks.find((node) => node.id === blockID);
if (!nodeSchema) {
console.error(`Schema not found for block ID: ${blockID}`);
return null;
}
// For agent blocks, fetch the full graph to get schemas
let inputSchema: BlockIORootSchema = nodeSchema.inputSchema;
let outputSchema: BlockIORootSchema = nodeSchema.outputSchema;
let finalHardcodedValues = hardcodedValues;
if (blockID === SpecialBlockID.AGENT) {
const graphID = hardcodedValues.graph_id as string;
const graphVersion = hardcodedValues.graph_version as number;
const graphData = okData(
await getV1GetSpecificGraph(graphID, { version: graphVersion }),
);
if (graphData) {
inputSchema = graphData.input_schema as BlockIORootSchema;
outputSchema = graphData.output_schema as BlockIORootSchema;
finalHardcodedValues = {
...hardcodedValues,
input_schema: graphData.input_schema,
output_schema: graphData.output_schema,
};
} else {
console.error("Failed to fetch graph data for agent block");
}
}
const newNode: CustomNode = {
id: nodeId.toString(),
type: "custom",
position,
data: {
blockType: blockName,
blockCosts: nodeSchema.costs || [],
title: `${blockName} ${nodeId}`,
description: nodeSchema.description,
categories: nodeSchema.categories,
inputSchema: inputSchema,
outputSchema: outputSchema,
hardcodedValues: finalHardcodedValues,
connections: [],
isOutputOpen: false,
block_id: blockID,
isOutputStatic: nodeSchema.staticOutput,
uiType: nodeSchema.uiType,
},
};
addNodes(newNode);
setNodeId((prevId) => prevId + 1);
clearNodesStatusAndOutput();
history.push({
type: "ADD_NODE",
payload: { node: { ...newNode, ...newNode.data } },
undo: () => deleteElements({ nodes: [{ id: newNode.id }] }),
redo: () => addNodes(newNode),
});
return newNode;
},
[
availableBlocks,
nodeId,
addNodes,
deleteElements,
clearNodesStatusAndOutput,
],
);
const addNode = useCallback( const addNode = useCallback(
async ( (blockId: string, nodeType: string, hardcodedValues: any = {}) => {
blockId: string,
nodeType: string,
hardcodedValues: Record<string, any> = {},
) => {
const nodeSchema = availableBlocks.find((node) => node.id === blockId); const nodeSchema = availableBlocks.find((node) => node.id === blockId);
if (!nodeSchema) { if (!nodeSchema) {
console.error(`Schema not found for block ID: ${blockId}`); console.error(`Schema not found for block ID: ${blockId}`);
@@ -797,42 +707,73 @@ const FlowEditor: React.FC<{
// Alternative: We could also use D3 force, Intersection for this (React flow Pro examples) // Alternative: We could also use D3 force, Intersection for this (React flow Pro examples)
const { x, y } = getViewport(); const { x, y } = getViewport();
const position = const viewportCoordinates =
nodeDimensions && Object.keys(nodeDimensions).length > 0 nodeDimensions && Object.keys(nodeDimensions).length > 0
? findNewlyAddedBlockCoordinates( ? // we will get all the dimension of nodes, then store
findNewlyAddedBlockCoordinates(
nodeDimensions, nodeDimensions,
nodeSchema.uiType == BlockUIType.NOTE ? 300 : 500, nodeSchema.uiType == BlockUIType.NOTE ? 300 : 500,
60, 60,
1.0, 1.0,
) )
: { : // we will get all the dimension of nodes, then store
{
x: window.innerWidth / 2 - x, x: window.innerWidth / 2 - x,
y: window.innerHeight / 2 - y, y: window.innerHeight / 2 - y,
}; };
const newNode = await createAndAddNode( const newNode: CustomNode = {
blockId, id: nodeId.toString(),
nodeType, type: "custom",
hardcodedValues, position: viewportCoordinates, // Set the position to the calculated viewport center
position, data: {
); blockType: nodeType,
if (!newNode) return; blockCosts: nodeSchema.costs,
title: `${nodeType} ${nodeId}`,
description: nodeSchema.description,
categories: nodeSchema.categories,
inputSchema: nodeSchema.inputSchema,
outputSchema: nodeSchema.outputSchema,
hardcodedValues: hardcodedValues,
connections: [],
isOutputOpen: false,
block_id: blockId,
isOutputStatic: nodeSchema.staticOutput,
uiType: nodeSchema.uiType,
},
};
addNodes(newNode);
setNodeId((prevId) => prevId + 1);
clearNodesStatusAndOutput(); // Clear status and output when a new node is added
setViewport( setViewport(
{ {
x: -position.x * 0.8 + (window.innerWidth - 0.0) / 2, // Rough estimate of the dimension of the node is: 500x400px.
y: -position.y * 0.8 + (window.innerHeight - 400) / 2, // Though we skip shifting the X, considering the block menu side-bar.
x: -viewportCoordinates.x * 0.8 + (window.innerWidth - 0.0) / 2,
y: -viewportCoordinates.y * 0.8 + (window.innerHeight - 400) / 2,
zoom: 0.8, zoom: 0.8,
}, },
{ duration: 500 }, { duration: 500 },
); );
history.push({
type: "ADD_NODE",
payload: { node: { ...newNode, ...newNode.data } },
undo: () => deleteElements({ nodes: [{ id: newNode.id }] }),
redo: () => addNodes(newNode),
});
}, },
[ [
nodeId,
getViewport, getViewport,
setViewport, setViewport,
availableBlocks, availableBlocks,
addNodes,
nodeDimensions, nodeDimensions,
createAndAddNode, deleteElements,
clearNodesStatusAndOutput,
], ],
); );
@@ -979,7 +920,7 @@ const FlowEditor: React.FC<{
}, []); }, []);
const onDrop = useCallback( const onDrop = useCallback(
async (event: React.DragEvent) => { (event: React.DragEvent) => {
event.preventDefault(); event.preventDefault();
const blockData = event.dataTransfer.getData("application/reactflow"); const blockData = event.dataTransfer.getData("application/reactflow");
@@ -994,17 +935,62 @@ const FlowEditor: React.FC<{
y: event.clientY, y: event.clientY,
}); });
await createAndAddNode( // Find the block schema
blockId, const nodeSchema = availableBlocks.find((node) => node.id === blockId);
blockName, if (!nodeSchema) {
hardcodedValues || {}, console.error(`Schema not found for block ID: ${blockId}`);
return;
}
// Create the new node at the drop position
const newNode: CustomNode = {
id: nodeId.toString(),
type: "custom",
position, position,
); data: {
blockType: blockName,
blockCosts: nodeSchema.costs || [],
title: `${blockName} ${nodeId}`,
description: nodeSchema.description,
categories: nodeSchema.categories,
inputSchema: nodeSchema.inputSchema,
outputSchema: nodeSchema.outputSchema,
hardcodedValues: hardcodedValues,
connections: [],
isOutputOpen: false,
block_id: blockId,
uiType: nodeSchema.uiType,
},
};
history.push({
type: "ADD_NODE",
payload: { node: { ...newNode, ...newNode.data } },
undo: () => {
deleteElements({ nodes: [{ id: newNode.id } as any], edges: [] });
},
redo: () => {
addNodes([newNode]);
},
});
addNodes([newNode]);
clearNodesStatusAndOutput();
setNodeId((prevId) => prevId + 1);
} catch (error) { } catch (error) {
console.error("Failed to drop block:", error); console.error("Failed to drop block:", error);
} }
}, },
[screenToFlowPosition, createAndAddNode], [
nodeId,
availableBlocks,
nodes,
edges,
addNodes,
screenToFlowPosition,
deleteElements,
clearNodesStatusAndOutput,
],
); );
const buildContextValue: BuilderContextType = useMemo( const buildContextValue: BuilderContextType = useMemo(

View File

@@ -1,14 +1,8 @@
import React, { useContext, useMemo, useState } from "react"; import React, { useContext, useState } from "react";
import { Button } from "@/components/__legacy__/ui/button"; import { Button } from "@/components/__legacy__/ui/button";
import { Maximize2 } from "lucide-react"; import { Maximize2 } from "lucide-react";
import * as Separator from "@radix-ui/react-separator"; import * as Separator from "@radix-ui/react-separator";
import { ContentRenderer } from "@/components/__legacy__/ui/render"; import { ContentRenderer } from "@/components/__legacy__/ui/render";
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
import {
globalRegistry,
OutputItem,
} from "@/components/contextual/OutputRenderers";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { beautifyString } from "@/lib/utils"; import { beautifyString } from "@/lib/utils";
@@ -27,9 +21,6 @@ export default function NodeOutputs({
data, data,
}: NodeOutputsProps) { }: NodeOutputsProps) {
const builderContext = useContext(BuilderContext); const builderContext = useContext(BuilderContext);
const enableEnhancedOutputHandling = useGetFlag(
Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
);
const [expandedDialog, setExpandedDialog] = useState<{ const [expandedDialog, setExpandedDialog] = useState<{
isOpen: boolean; isOpen: boolean;
@@ -46,15 +37,6 @@ export default function NodeOutputs({
const { getNodeTitle } = builderContext; const { getNodeTitle } = builderContext;
// Prepare renderers for each item when enhanced mode is enabled
const getItemRenderer = useMemo(() => {
if (!enableEnhancedOutputHandling) return null;
return (item: unknown) => {
const metadata: OutputMetadata = {};
return globalRegistry.getRenderer(item, metadata);
};
}, [enableEnhancedOutputHandling]);
const getBeautifiedPinName = (pin: string) => { const getBeautifiedPinName = (pin: string) => {
if (!pin.startsWith("tools_^_")) { if (!pin.startsWith("tools_^_")) {
return beautifyString(pin); return beautifyString(pin);
@@ -105,31 +87,15 @@ export default function NodeOutputs({
<div className="mt-2"> <div className="mt-2">
<strong className="mr-2">Data:</strong> <strong className="mr-2">Data:</strong>
<div className="mt-1"> <div className="mt-1">
{dataArray.slice(0, 10).map((item, index) => { {dataArray.slice(0, 10).map((item, index) => (
const renderer = getItemRenderer?.(item); <React.Fragment key={index}>
if (enableEnhancedOutputHandling && renderer) { <ContentRenderer
const metadata: OutputMetadata = {}; value={item}
return ( truncateLongData={truncateLongData}
<React.Fragment key={index}> />
<OutputItem {index < Math.min(dataArray.length, 10) - 1 && ", "}
value={item} </React.Fragment>
metadata={metadata} ))}
renderer={renderer}
/>
{index < Math.min(dataArray.length, 10) - 1 && ", "}
</React.Fragment>
);
}
return (
<React.Fragment key={index}>
<ContentRenderer
value={item}
truncateLongData={truncateLongData}
/>
{index < Math.min(dataArray.length, 10) - 1 && ", "}
</React.Fragment>
);
})}
{dataArray.length > 10 && ( {dataArray.length > 10 && (
<span style={{ color: "#888" }}> <span style={{ color: "#888" }}>
<br /> <br />

Some files were not shown because too many files have changed in this diff Show More