mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-10 14:55:16 -05:00
Compare commits
65 Commits
kpczerwins
...
fix/copilo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6573d987ea | ||
|
|
ae8ce8b4ca | ||
|
|
81c1524658 | ||
|
|
f2ead70f3d | ||
|
|
7d4c020a9b | ||
|
|
e596ea87cb | ||
|
|
81f8290f01 | ||
|
|
6467f6734f | ||
|
|
5a30d11416 | ||
|
|
1f4105e8f9 | ||
|
|
caf9ff34e6 | ||
|
|
e8fc8ee623 | ||
|
|
1a16e203b8 | ||
|
|
5dae303ce0 | ||
|
|
6cbfbdd013 | ||
|
|
0c6fa60436 | ||
|
|
b04e916c23 | ||
|
|
1a32ba7d9a | ||
|
|
deccc26f1f | ||
|
|
9e38bd5b78 | ||
|
|
a329831b0b | ||
|
|
98dd1a9480 | ||
|
|
9c7c598c7d | ||
|
|
728c40def5 | ||
|
|
cd64562e1b | ||
|
|
8fddc9d71f | ||
|
|
3d1cd03fc8 | ||
|
|
e7ebe42306 | ||
|
|
e0fab7e34e | ||
|
|
29ee85c86f | ||
|
|
85b6520710 | ||
|
|
bfa942e032 | ||
|
|
11256076d8 | ||
|
|
3ca2387631 | ||
|
|
ed07f02738 | ||
|
|
b121030c94 | ||
|
|
c22c18374d | ||
|
|
e40233a3ac | ||
|
|
3ae5eabf9d | ||
|
|
a077ba9f03 | ||
|
|
5401d54eaa | ||
|
|
5ac89d7c0b | ||
|
|
4f908d5cb3 | ||
|
|
c1aa684743 | ||
|
|
7e5b84cc5c | ||
|
|
09cb313211 | ||
|
|
c026485023 | ||
|
|
1eabc60484 | ||
|
|
f4bf492f24 | ||
|
|
81e48c00a4 | ||
|
|
7dc53071e8 | ||
|
|
4878665c66 | ||
|
|
678ddde751 | ||
|
|
aef6f57cfd | ||
|
|
f7350c797a | ||
|
|
2abbb7fbc8 | ||
|
|
05b60db554 | ||
|
|
cc4839bedb | ||
|
|
dbbff04616 | ||
|
|
e6438b9a76 | ||
|
|
e10ff8d37f | ||
|
|
9538992eaf | ||
|
|
27b72062f2 | ||
|
|
9a79a8d257 | ||
|
|
a9bf08748b |
2
.github/workflows/classic-frontend-ci.yml
vendored
2
.github/workflows/classic-frontend-ci.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
uses: peter-evans/create-pull-request@v8
|
||||
with:
|
||||
add-paths: classic/frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
|
||||
@@ -42,7 +42,7 @@ jobs:
|
||||
|
||||
- name: Get CI failure details
|
||||
id: failure_details
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const run = await github.rest.actions.getWorkflowRun({
|
||||
|
||||
9
.github/workflows/claude-dependabot.yml
vendored
9
.github/workflows/claude-dependabot.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
@@ -91,7 +91,7 @@ jobs:
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
@@ -124,7 +124,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
@@ -309,6 +309,7 @@ jobs:
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
allowed_bots: "dependabot[bot]"
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
|
||||
8
.github/workflows/claude.yml
vendored
8
.github/workflows/claude.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
@@ -107,7 +107,7 @@ jobs:
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
@@ -140,7 +140,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
|
||||
8
.github/workflows/copilot-setup-steps.yml
vendored
8
.github/workflows/copilot-setup-steps.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
@@ -89,7 +89,7 @@ jobs:
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
@@ -132,7 +132,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
|
||||
2
.github/workflows/docs-block-sync.yml
vendored
2
.github/workflows/docs-block-sync.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
2
.github/workflows/docs-claude-review.yml
vendored
2
.github/workflows/docs-claude-review.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
2
.github/workflows/docs-enhance.yml
vendored
2
.github/workflows/docs-enhance.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -88,7 +88,7 @@ jobs:
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Check comment permissions and deployment status
|
||||
id: check_status
|
||||
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const commentBody = context.payload.comment.body.trim();
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
|
||||
- name: Post permission denied comment
|
||||
if: steps.check_status.outputs.permission_denied == 'true'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
- name: Get PR details for deployment
|
||||
id: pr_details
|
||||
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const pr = await github.rest.pulls.get({
|
||||
@@ -98,7 +98,7 @@ jobs:
|
||||
|
||||
- name: Post deploy success comment
|
||||
if: steps.check_status.outputs.should_deploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -126,7 +126,7 @@ jobs:
|
||||
|
||||
- name: Post undeploy success comment
|
||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -139,7 +139,7 @@ jobs:
|
||||
- name: Check deployment status on PR close
|
||||
id: check_pr_close
|
||||
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const comments = await github.rest.issues.listComments({
|
||||
@@ -187,7 +187,7 @@ jobs:
|
||||
github.event_name == 'pull_request' &&
|
||||
github.event.action == 'closed' &&
|
||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
|
||||
38
.github/workflows/platform-frontend-ci.yml
vendored
38
.github/workflows/platform-frontend-ci.yml
vendored
@@ -27,13 +27,22 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
components-changed: ${{ steps.filter.outputs.components }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check for component changes
|
||||
uses: dorny/paths-filter@v3
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
components:
|
||||
- 'autogpt_platform/frontend/src/components/**'
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -45,7 +54,7 @@ jobs:
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
@@ -65,7 +74,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -73,7 +82,7 @@ jobs:
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
@@ -90,8 +99,11 @@ jobs:
|
||||
chromatic:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
# Only run on dev branch pushes or PRs targeting dev
|
||||
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
|
||||
# Disabled: to re-enable, remove 'false &&' from the condition below
|
||||
if: >-
|
||||
false
|
||||
&& (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
|
||||
&& needs.setup.outputs.components-changed == 'true'
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -100,7 +112,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -108,7 +120,7 @@ jobs:
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
@@ -141,7 +153,7 @@ jobs:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -164,7 +176,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -219,7 +231,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
@@ -270,7 +282,7 @@ jobs:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -278,7 +290,7 @@ jobs:
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
|
||||
12
.github/workflows/platform-fullstack-ci.yml
vendored
12
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
types:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -85,10 +85,10 @@ jobs:
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -180,3 +180,4 @@ autogpt_platform/backend/settings.py
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
.next
|
||||
1862
autogpt_platform/autogpt_libs/poetry.lock
generated
1862
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
cryptography = "^46.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
fastapi = "^0.128.0"
|
||||
google-cloud-logging = "^3.13.0"
|
||||
launchdarkly-server-sdk = "^9.14.1"
|
||||
pydantic = "^2.12.5"
|
||||
pydantic-settings = "^2.12.0"
|
||||
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
supabase = "^2.27.2"
|
||||
uvicorn = "^0.40.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pyright = "^1.1.404"
|
||||
pyright = "^1.1.408"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
ruff = "^0.12.11"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-cov = "^7.0.0"
|
||||
ruff = "^0.15.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -152,6 +152,7 @@ REPLICATE_API_KEY=
|
||||
REVID_API_KEY=
|
||||
SCREENSHOTONE_API_KEY=
|
||||
UNREAL_SPEECH_API_KEY=
|
||||
ELEVENLABS_API_KEY=
|
||||
|
||||
# Data & Search Services
|
||||
E2B_API_KEY=
|
||||
|
||||
3
autogpt_platform/backend/.gitignore
vendored
3
autogpt_platform/backend/.gitignore
vendored
@@ -19,3 +19,6 @@ load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
migrations/*/rollback*.sql
|
||||
|
||||
# Workspace files
|
||||
workspaces/
|
||||
|
||||
@@ -62,10 +62,12 @@ ENV POETRY_HOME=/opt/poetry \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Install Python without upgrading system-managed packages
|
||||
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.13 \
|
||||
python3-pip \
|
||||
ffmpeg \
|
||||
imagemagick \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy only necessary files from builder
|
||||
|
||||
@@ -0,0 +1,368 @@
|
||||
"""Redis Streams consumer for operation completion messages.
|
||||
|
||||
This module provides a consumer (ChatCompletionConsumer) that listens for
|
||||
completion notifications (OperationCompleteMessage) from external services
|
||||
(like Agent Generator) and triggers the appropriate stream registry and
|
||||
chat service updates via process_operation_success/process_operation_failure.
|
||||
|
||||
Why Redis Streams instead of RabbitMQ?
|
||||
--------------------------------------
|
||||
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
||||
queue), Redis Streams was chosen for chat completion notifications because:
|
||||
|
||||
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
||||
Streams (via stream_registry) for message persistence and replay. Using Redis
|
||||
Streams for completion notifications keeps all chat streaming infrastructure
|
||||
in one system, simplifying operations and reducing cross-system coordination.
|
||||
|
||||
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
||||
allowing consumers to replay missed messages after reconnection. This aligns
|
||||
with the SSE reconnection pattern where clients can resume from last_message_id.
|
||||
|
||||
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
||||
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
||||
recovering from dead consumers - ideal for the completion callback pattern.
|
||||
|
||||
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
||||
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
||||
|
||||
5. **Atomicity with Task State**: Completion processing often needs to update
|
||||
task metadata stored in Redis. Keeping both in Redis enables simpler
|
||||
transactional semantics without distributed coordination.
|
||||
|
||||
The consumer uses Redis Streams with consumer groups for reliable message
|
||||
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
||||
stale pending messages from dead consumers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
|
||||
"""Message format for operation completion notifications."""
|
||||
|
||||
operation_id: str
|
||||
task_id: str
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
|
||||
"""Consumer for chat operation completion messages from Redis Streams.
|
||||
|
||||
This consumer initializes its own Prisma client in start() to ensure
|
||||
database operations work correctly within this async context.
|
||||
|
||||
Uses Redis consumer groups to allow multiple platform pods to consume
|
||||
messages reliably with automatic redelivery on failure.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._consumer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self._prisma: Prisma | None = None
|
||||
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the completion consumer."""
|
||||
if self._running:
|
||||
logger.warning("Completion consumer already running")
|
||||
return
|
||||
|
||||
# Create consumer group if it doesn't exist
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.xgroup_create(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
id="0",
|
||||
mkstream=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Created consumer group '{config.stream_consumer_group}' "
|
||||
f"on stream '{config.stream_completion_name}'"
|
||||
)
|
||||
except ResponseError as e:
|
||||
if "BUSYGROUP" in str(e):
|
||||
logger.debug(
|
||||
f"Consumer group '{config.stream_consumer_group}' already exists"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
self._running = True
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
logger.info(
|
||||
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
||||
)
|
||||
|
||||
async def _ensure_prisma(self) -> Prisma:
|
||||
"""Lazily initialize Prisma client on first use."""
|
||||
if self._prisma is None:
|
||||
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
||||
self._prisma = Prisma(datasource={"url": database_url})
|
||||
await self._prisma.connect()
|
||||
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
|
||||
return self._prisma
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the completion consumer."""
|
||||
self._running = False
|
||||
|
||||
if self._consumer_task:
|
||||
self._consumer_task.cancel()
|
||||
try:
|
||||
await self._consumer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._consumer_task = None
|
||||
|
||||
if self._prisma:
|
||||
await self._prisma.disconnect()
|
||||
self._prisma = None
|
||||
logger.info("[COMPLETION] Consumer Prisma client disconnected")
|
||||
|
||||
logger.info("Chat completion consumer stopped")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
"""Main message consumption loop with retry logic."""
|
||||
max_retries = 10
|
||||
retry_delay = 5 # seconds
|
||||
retry_count = 0
|
||||
block_timeout = 5000 # milliseconds
|
||||
|
||||
while self._running and retry_count < max_retries:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Reset retry count on successful connection
|
||||
retry_count = 0
|
||||
|
||||
while self._running:
|
||||
# First, claim any stale pending messages from dead consumers
|
||||
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
||||
# claim them using XAUTOCLAIM
|
||||
try:
|
||||
claimed_result = await redis.xautoclaim(
|
||||
name=config.stream_completion_name,
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
min_idle_time=config.stream_claim_min_idle_ms,
|
||||
start_id="0-0",
|
||||
count=10,
|
||||
)
|
||||
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
||||
if claimed_result and len(claimed_result) >= 2:
|
||||
claimed_entries = claimed_result[1]
|
||||
if claimed_entries:
|
||||
logger.info(
|
||||
f"Claimed {len(claimed_entries)} stale pending messages"
|
||||
)
|
||||
for entry_id, data in claimed_entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
||||
|
||||
# Read new messages from the stream
|
||||
messages = await redis.xreadgroup(
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
streams={config.stream_completion_name: ">"},
|
||||
block=block_timeout,
|
||||
count=10,
|
||||
)
|
||||
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
for stream_name, entries in messages:
|
||||
for entry_id, data in entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Consumer cancelled")
|
||||
return
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
logger.error(
|
||||
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if self._running and retry_count < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.error("Max retries reached, stopping consumer")
|
||||
return
|
||||
|
||||
async def _process_entry(
|
||||
self, redis: Any, entry_id: str, data: dict[str, Any]
|
||||
) -> None:
|
||||
"""Process a single stream entry and acknowledge it on success.
|
||||
|
||||
Args:
|
||||
redis: Redis client connection
|
||||
entry_id: The stream entry ID
|
||||
data: The entry data dict
|
||||
"""
|
||||
try:
|
||||
# Handle the message
|
||||
message_data = data.get("data")
|
||||
if message_data:
|
||||
await self._handle_message(
|
||||
message_data.encode()
|
||||
if isinstance(message_data, str)
|
||||
else message_data
|
||||
)
|
||||
|
||||
# Acknowledge the message after successful processing
|
||||
await redis.xack(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
entry_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing completion message {entry_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Message remains in pending state and will be claimed by
|
||||
# XAUTOCLAIM after min_idle_time expires
|
||||
|
||||
async def _handle_message(self, body: bytes) -> None:
|
||||
"""Handle a completion message using our own Prisma client."""
|
||||
try:
|
||||
data = orjson.loads(body)
|
||||
message = OperationCompleteMessage(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse completion message: {e}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Received completion for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id}, success={message.success})"
|
||||
)
|
||||
|
||||
# Find task in registry
|
||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||
if task is None:
|
||||
task = await stream_registry.get_task(message.task_id)
|
||||
|
||||
if task is None:
|
||||
logger.warning(
|
||||
f"[COMPLETION] Task not found for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id})"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Found task: task_id={task.task_id}, "
|
||||
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
|
||||
)
|
||||
|
||||
# Guard against empty task fields
|
||||
if not task.task_id or not task.session_id or not task.tool_call_id:
|
||||
logger.error(
|
||||
f"[COMPLETION] Task has empty critical fields! "
|
||||
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
|
||||
f"tool_call_id={task.tool_call_id!r}"
|
||||
)
|
||||
return
|
||||
|
||||
if message.success:
|
||||
await self._handle_success(task, message)
|
||||
else:
|
||||
await self._handle_failure(task, message)
|
||||
|
||||
async def _handle_success(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle successful operation completion."""
|
||||
prisma = await self._ensure_prisma()
|
||||
await process_operation_success(task, message.result, prisma)
|
||||
|
||||
async def _handle_failure(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle failed operation completion."""
|
||||
prisma = await self._ensure_prisma()
|
||||
await process_operation_failure(task, message.error, prisma)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
_consumer: ChatCompletionConsumer | None = None
|
||||
|
||||
|
||||
async def start_completion_consumer() -> None:
|
||||
"""Start the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer is None:
|
||||
_consumer = ChatCompletionConsumer()
|
||||
await _consumer.start()
|
||||
|
||||
|
||||
async def stop_completion_consumer() -> None:
|
||||
"""Stop the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer:
|
||||
await _consumer.stop()
|
||||
_consumer = None
|
||||
|
||||
|
||||
async def publish_operation_complete(
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
success: bool,
|
||||
result: dict | str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Publish an operation completion message to Redis Streams.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID that completed.
|
||||
task_id: The task ID associated with the operation.
|
||||
success: Whether the operation succeeded.
|
||||
result: The result data (for success).
|
||||
error: The error message (for failure).
|
||||
"""
|
||||
message = OperationCompleteMessage(
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
success=success,
|
||||
result=result,
|
||||
error=error,
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
await redis.xadd(
|
||||
config.stream_completion_name,
|
||||
{"data": message.model_dump_json()},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
logger.info(f"Published completion for operation {operation_id}")
|
||||
@@ -0,0 +1,344 @@
|
||||
"""Shared completion handling for operation success and failure.
|
||||
|
||||
This module provides common logic for handling operation completion from both:
|
||||
- The Redis Streams consumer (completion_consumer.py)
|
||||
- The HTTP webhook endpoint (routes.py)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that produce agent_json that needs to be saved to library
|
||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
||||
|
||||
# Keys that should be stripped from agent_json when returning in error responses
|
||||
SENSITIVE_KEYS = frozenset(
|
||||
{
|
||||
"api_key",
|
||||
"apikey",
|
||||
"api_secret",
|
||||
"password",
|
||||
"secret",
|
||||
"credentials",
|
||||
"credential",
|
||||
"token",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"private_key",
|
||||
"privatekey",
|
||||
"auth",
|
||||
"authorization",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_agent_json(obj: Any) -> Any:
|
||||
"""Recursively sanitize agent_json by removing sensitive keys.
|
||||
|
||||
Args:
|
||||
obj: The object to sanitize (dict, list, or primitive)
|
||||
|
||||
Returns:
|
||||
Sanitized copy with sensitive keys removed/redacted
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [_sanitize_agent_json(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class ToolMessageUpdateError(Exception):
|
||||
"""Raised when updating a tool message in the database fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def _update_tool_message(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
prisma_client: Prisma | None,
|
||||
) -> None:
|
||||
"""Update tool message in database.
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
tool_call_id: The tool call ID to update
|
||||
content: The new content for the message
|
||||
prisma_client: Optional Prisma client. If None, uses chat_service.
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The caller should
|
||||
handle this to avoid marking the task as completed with inconsistent state.
|
||||
"""
|
||||
try:
|
||||
if prisma_client:
|
||||
# Use provided Prisma client (for consumer with its own connection)
|
||||
updated_count = await prisma_client.chatmessage.update_many(
|
||||
where={
|
||||
"sessionId": session_id,
|
||||
"toolCallId": tool_call_id,
|
||||
},
|
||||
data={"content": content},
|
||||
)
|
||||
# Check if any rows were updated - 0 means message not found
|
||||
if updated_count == 0:
|
||||
raise ToolMessageUpdateError(
|
||||
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
|
||||
)
|
||||
else:
|
||||
# Use service function (for webhook endpoint)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
result=content,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
|
||||
raise ToolMessageUpdateError(
|
||||
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
||||
"""Serialize result to JSON string with sensible defaults.
|
||||
|
||||
Args:
|
||||
result: The result to serialize. Can be a dict, list, string,
|
||||
number, boolean, or None.
|
||||
|
||||
Returns:
|
||||
JSON string representation of the result. Returns '{"status": "completed"}'
|
||||
only when result is explicitly None.
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if result is None:
|
||||
return '{"status": "completed"}'
|
||||
return orjson.dumps(result).decode("utf-8")
|
||||
|
||||
|
||||
async def _save_agent_from_result(
|
||||
result: dict[str, Any],
|
||||
user_id: str | None,
|
||||
tool_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Save agent to library if result contains agent_json.
|
||||
|
||||
Args:
|
||||
result: The result dict that may contain agent_json
|
||||
user_id: The user ID to save the agent for
|
||||
tool_name: The tool name (create_agent or edit_agent)
|
||||
|
||||
Returns:
|
||||
Updated result dict with saved agent details, or original result if no agent_json
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
|
||||
return result
|
||||
|
||||
agent_json = result.get("agent_json")
|
||||
if not agent_json:
|
||||
logger.warning(
|
||||
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
||||
)
|
||||
return result
|
||||
|
||||
try:
|
||||
from .tools.agent_generator import save_agent_to_library
|
||||
|
||||
is_update = tool_name == "edit_agent"
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, is_update=is_update
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
||||
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
||||
)
|
||||
|
||||
# Return a response similar to AgentSavedResponse
|
||||
return {
|
||||
"type": "agent_saved",
|
||||
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
"agent_id": created_graph.id,
|
||||
"agent_name": created_graph.name,
|
||||
"library_agent_id": library_agent.id,
|
||||
"library_agent_link": f"/library/agents/{library_agent.id}",
|
||||
"agent_page_link": f"/build?flowID={created_graph.id}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to save agent to library: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Return error but don't fail the whole operation
|
||||
# Sanitize agent_json to remove sensitive keys before returning
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Agent was generated but failed to save: {str(e)}",
|
||||
"error": str(e),
|
||||
"agent_json": _sanitize_agent_json(agent_json),
|
||||
}
|
||||
|
||||
|
||||
async def process_operation_success(
|
||||
task: stream_registry.ActiveTask,
|
||||
result: dict | str | None,
|
||||
prisma_client: Prisma | None = None,
|
||||
) -> None:
|
||||
"""Handle successful operation completion.
|
||||
|
||||
Publishes the result to the stream registry, updates the database,
|
||||
generates LLM continuation, and marks the task as completed.
|
||||
|
||||
Args:
|
||||
task: The active task that completed
|
||||
result: The result data from the operation
|
||||
prisma_client: Optional Prisma client for database operations.
|
||||
If None, uses chat_service._update_pending_operation instead.
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The task will be
|
||||
marked as failed instead of completed to avoid inconsistent state.
|
||||
"""
|
||||
# For agent generation tools, save the agent to library
|
||||
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
||||
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
||||
|
||||
# Serialize result for output (only substitute default when result is exactly None)
|
||||
result_output = result if result is not None else {"status": "completed"}
|
||||
output_str = (
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
)
|
||||
|
||||
# Publish result to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=output_str,
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
# If this fails, we must not continue to mark the task as completed
|
||||
result_str = serialize_result(result)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=result_str,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - mark task as failed to avoid inconsistent state
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed for task {task.task_id}, "
|
||||
"marking as failed instead of completed"
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText="Failed to save operation result to database"),
|
||||
)
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
raise
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
try:
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||
)
|
||||
|
||||
|
||||
async def process_operation_failure(
|
||||
task: stream_registry.ActiveTask,
|
||||
error: str | None,
|
||||
prisma_client: Prisma | None = None,
|
||||
) -> None:
|
||||
"""Handle failed operation completion.
|
||||
|
||||
Publishes the error to the stream registry, updates the database with
|
||||
the error response, and marks the task as failed.
|
||||
|
||||
Args:
|
||||
task: The active task that failed
|
||||
error: The error message from the operation
|
||||
prisma_client: Optional Prisma client for database operations.
|
||||
If None, uses chat_service._update_pending_operation instead.
|
||||
"""
|
||||
error_msg = error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
|
||||
# Update pending operation with error
|
||||
# If this fails, we still continue to mark the task as failed
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=error,
|
||||
)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=error_response.model_dump_json(),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - log but continue with cleanup
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
|
||||
"continuing with cleanup"
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
||||
@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
# OpenAI API Configuration
|
||||
model: str = Field(
|
||||
default="anthropic/claude-opus-4.5", description="Default model to use"
|
||||
default="anthropic/claude-opus-4.6", description="Default model to use"
|
||||
)
|
||||
title_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
@@ -44,6 +44,48 @@ class ChatConfig(BaseSettings):
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
stream_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=10000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
|
||||
# Redis Streams configuration for completion consumer
|
||||
stream_completion_name: str = Field(
|
||||
default="chat:completions",
|
||||
description="Redis Stream name for operation completions",
|
||||
)
|
||||
stream_consumer_group: str = Field(
|
||||
default="chat_consumers",
|
||||
description="Consumer group name for completion stream",
|
||||
)
|
||||
stream_claim_min_idle_ms: int = Field(
|
||||
default=60000,
|
||||
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
|
||||
)
|
||||
|
||||
# Redis key prefixes for stream registry
|
||||
task_meta_prefix: str = Field(
|
||||
default="chat:task:meta:",
|
||||
description="Prefix for task metadata hash keys",
|
||||
)
|
||||
task_stream_prefix: str = Field(
|
||||
default="chat:stream:",
|
||||
description="Prefix for task message stream keys",
|
||||
)
|
||||
task_op_prefix: str = Field(
|
||||
default="chat:task:op:",
|
||||
description="Prefix for operation ID to task ID mapping keys",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||
langfuse_prompt_name: str = Field(
|
||||
@@ -82,6 +124,14 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("internal_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_internal_api_key(cls, v):
|
||||
"""Get internal API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -45,10 +45,7 @@ async def create_chat_session(
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(data=data)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
|
||||
@@ -18,6 +18,10 @@ class ResponseType(str, Enum):
|
||||
START = "start"
|
||||
FINISH = "finish"
|
||||
|
||||
# Step lifecycle (one LLM API call within a message)
|
||||
START_STEP = "start-step"
|
||||
FINISH_STEP = "finish-step"
|
||||
|
||||
# Text streaming
|
||||
TEXT_START = "text-start"
|
||||
TEXT_DELTA = "text-delta"
|
||||
@@ -52,6 +56,20 @@ class StreamStart(StreamBaseResponse):
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
taskId: str | None = Field(
|
||||
default=None,
|
||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-protocol fields like taskId."""
|
||||
import json
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"type": self.type.value,
|
||||
"messageId": self.messageId,
|
||||
}
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
class StreamFinish(StreamBaseResponse):
|
||||
@@ -60,6 +78,26 @@ class StreamFinish(StreamBaseResponse):
|
||||
type: ResponseType = ResponseType.FINISH
|
||||
|
||||
|
||||
class StreamStartStep(StreamBaseResponse):
|
||||
"""Start of a step (one LLM API call within a message).
|
||||
|
||||
The AI SDK uses this to add a step-start boundary to message.parts,
|
||||
enabling visual separation between multiple LLM calls in a single message.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.START_STEP
|
||||
|
||||
|
||||
class StreamFinishStep(StreamBaseResponse):
|
||||
"""End of a step (one LLM API call within a message).
|
||||
|
||||
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
|
||||
so the next LLM call in a tool-call continuation starts with clean state.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.FINISH_STEP
|
||||
|
||||
|
||||
# ========== Text Streaming ==========
|
||||
|
||||
|
||||
@@ -113,7 +151,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
||||
# Keep these for internal backend use
|
||||
toolName: str | None = Field(
|
||||
default=None, description="Name of the tool that was executed"
|
||||
)
|
||||
@@ -121,6 +159,17 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-spec fields."""
|
||||
import json
|
||||
|
||||
data = {
|
||||
"type": self.type.value,
|
||||
"toolCallId": self.toolCallId,
|
||||
"output": self.output,
|
||||
}
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
# ========== Other ==========
|
||||
|
||||
|
||||
@@ -1,19 +1,45 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat
|
||||
from .tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AgentsFoundResponse,
|
||||
BlockListResponse,
|
||||
BlockOutputResponse,
|
||||
ClarificationNeededResponse,
|
||||
DocPageResponse,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
ExecutionStartedResponse,
|
||||
InputValidationErrorResponse,
|
||||
NeedLoginResponse,
|
||||
NoResultsResponse,
|
||||
OperationInProgressResponse,
|
||||
OperationPendingResponse,
|
||||
OperationStartedResponse,
|
||||
SetupRequirementsResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -55,6 +81,15 @@ class CreateSessionResponse(BaseModel):
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
"""Information about an active stream for reconnection."""
|
||||
|
||||
task_id: str
|
||||
last_message_id: str # Redis Stream message ID for resumption
|
||||
operation_id: str # Operation ID for completion tracking
|
||||
tool_name: str # Name of the tool being executed
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model providing complete details for a chat session, including messages."""
|
||||
|
||||
@@ -63,6 +98,7 @@ class SessionDetailResponse(BaseModel):
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -81,6 +117,14 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -166,13 +210,14 @@ async def get_session(
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns the task_id for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
||||
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
@@ -180,11 +225,28 @@ async def get_session(
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
# stream replay instead, preventing duplicate content.
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
messages = messages[:-1]
|
||||
|
||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||
# Since we filtered out the cached assistant message, the client needs
|
||||
# the full stream to reconstruct the response.
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
task_id=active_task.task_id,
|
||||
last_message_id="0-0",
|
||||
operation_id=active_task.operation_id,
|
||||
tool_name=active_task.tool_name,
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
@@ -192,6 +254,7 @@ async def get_session(
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -211,49 +274,264 @@ async def stream_chat_post(
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||
containing the task_id for reconnection.
|
||||
|
||||
"""
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
stream_start_time = time.perf_counter()
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||
f"user={user_id}, message_len={len(request.message)}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
logger.info(
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
log_meta["task_id"] = task_id
|
||||
|
||||
task_create_start = time.perf_counter()
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||
tool_name="chat",
|
||||
operation_id=operation_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
import time as time_module
|
||||
|
||||
gen_start_time = time_module.perf_counter()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
first_chunk_time, ttfc = None, None
|
||||
chunk_count = 0
|
||||
try:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
|
||||
):
|
||||
chunk_count += 1
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = time_module.perf_counter()
|
||||
ttfc = first_chunk_time - gen_start_time
|
||||
logger.info(
|
||||
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
"time_to_first_chunk_ms": ttfc * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Write to Redis (subscribers will receive via XREAD)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
|
||||
gen_end_time = time_module.perf_counter()
|
||||
total_time = (gen_end_time - gen_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
||||
f"task={task_id}, session={session_id}, "
|
||||
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"time_to_first_chunk_ms": (
|
||||
ttfc * 1000 if ttfc is not None else None
|
||||
),
|
||||
"n_chunks": chunk_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
except Exception as e:
|
||||
elapsed = time_module.perf_counter() - gen_start_time
|
||||
logger.error(
|
||||
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed * 1000,
|
||||
"error": str(e),
|
||||
}
|
||||
},
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "failed")
|
||||
|
||||
# Start the AI generation in a background task
|
||||
bg_task = asyncio.create_task(run_ai_generation())
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import time as time_module
|
||||
|
||||
event_gen_start = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
||||
f"user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
subscriber_queue = None
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
try:
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Get all messages from the beginning
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
logger.info(
|
||||
"[TIMING] Starting to read from subscriber_queue",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
chunks_yielded += 1
|
||||
|
||||
if not first_chunk_yielded:
|
||||
first_chunk_yielded = True
|
||||
elapsed = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
|
||||
f"type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
"elapsed_ms": elapsed * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
|
||||
f"n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
"total_time_ms": total_time * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.info(
|
||||
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
"reason": "client_disconnect",
|
||||
}
|
||||
},
|
||||
)
|
||||
pass # Client disconnected - background task continues
|
||||
except Exception as e:
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
|
||||
extra={
|
||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||
},
|
||||
)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
task_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time * 1000,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
}
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -270,63 +548,90 @@ async def stream_chat_post(
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_get(
|
||||
async def resume_session_stream(
|
||||
session_id: str,
|
||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
is_user_message: bool = Query(default=True),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (GET - legacy endpoint).
|
||||
Resume an active stream for a session.
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
Called by the AI SDK's ``useChat(resume: true)`` on page load.
|
||||
Checks for an active (in-progress) task on the session and either replays
|
||||
the full SSE stream or returns 204 No Content if nothing is running.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
message: The user's new message to process.
|
||||
session_id: The chat session identifier.
|
||||
user_id: Optional authenticated user ID.
|
||||
is_user_message: Whether the message is a user message.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
Returns:
|
||||
StreamingResponse (SSE) when an active stream exists,
|
||||
or 204 No Content when there is nothing to resume.
|
||||
"""
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
import asyncio
|
||||
|
||||
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
|
||||
if not active_task:
|
||||
return Response(status_code=204)
|
||||
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=active_task.task_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
return Response(status_code=204)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Resume stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
active_task.task_id, subscriber_queue
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(
|
||||
"Resume stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"n_chunks": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -334,8 +639,8 @@ async def stream_chat_get(
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -366,6 +671,251 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/stream",
|
||||
)
|
||||
async def stream_task(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
last_message_id: str = Query(
|
||||
default="0-0",
|
||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Reconnect to a long-running task's SSE stream.
|
||||
|
||||
When a long-running operation (like agent generation) starts, the client
|
||||
receives a task_id. If the connection drops, the client can reconnect
|
||||
using this endpoint to resume receiving updates.
|
||||
|
||||
Args:
|
||||
task_id: The task ID from the operation_started response.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
||||
"""
|
||||
# Check task existence and expiry before subscribing
|
||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
||||
|
||||
if error_code == "TASK_EXPIRED":
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail={
|
||||
"code": "TASK_EXPIRED",
|
||||
"message": "This operation has expired. Please try again.",
|
||||
},
|
||||
)
|
||||
|
||||
if error_code == "TASK_NOT_FOUND":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found.",
|
||||
},
|
||||
)
|
||||
|
||||
# Validate ownership if task has an owner
|
||||
if task and task.user_id and user_id != task.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"code": "ACCESS_DENIED",
|
||||
"message": "You do not have access to this task.",
|
||||
},
|
||||
)
|
||||
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found or access denied.",
|
||||
},
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for next chunk with timeout for heartbeats
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(), timeout=heartbeat_interval
|
||||
)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Get the status of a long-running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
|
||||
Returns:
|
||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task.user_id and user_id != task.user_id:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"session_id": task.session_id,
|
||||
"status": task.status,
|
||||
"tool_name": task.tool_name,
|
||||
"operation_id": task.operation_id,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/operations/{operation_id}/complete",
|
||||
status_code=200,
|
||||
)
|
||||
async def complete_operation(
|
||||
operation_id: str,
|
||||
request: OperationCompleteRequest,
|
||||
x_api_key: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""
|
||||
External completion webhook for long-running operations.
|
||||
|
||||
Called by Agent Generator (or other services) when an operation completes.
|
||||
This triggers the stream registry to publish completion and continue LLM generation.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to complete.
|
||||
request: Completion payload with success status and result/error.
|
||||
x_api_key: Internal API key for authentication.
|
||||
|
||||
Returns:
|
||||
dict: Status of the completion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or operation not found.
|
||||
"""
|
||||
# Validate internal API key - reject if not configured or invalid
|
||||
if not config.internal_api_key:
|
||||
logger.error(
|
||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Webhook not available: internal API key not configured",
|
||||
)
|
||||
if x_api_key != config.internal_api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# Find task by operation_id
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Operation {operation_id} not found",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received completion webhook for operation {operation_id} "
|
||||
f"(task_id={task.task_id}, success={request.success})"
|
||||
)
|
||||
|
||||
if request.success:
|
||||
await process_operation_success(task, request.result)
|
||||
else:
|
||||
await process_operation_failure(task, request.error)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
@router.get("/config/ttl", status_code=200)
|
||||
async def get_ttl_config() -> dict:
|
||||
"""
|
||||
Get the stream TTL configuration.
|
||||
|
||||
Returns the Time-To-Live settings for chat streams, which determines
|
||||
how long clients can reconnect to an active stream.
|
||||
|
||||
Returns:
|
||||
dict: TTL configuration with seconds and milliseconds values.
|
||||
"""
|
||||
return {
|
||||
"stream_ttl_seconds": config.stream_ttl,
|
||||
"stream_ttl_ms": config.stream_ttl * 1000,
|
||||
}
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
@@ -402,3 +952,42 @@ async def health_check() -> dict:
|
||||
"service": "chat",
|
||||
"version": "0.1.0",
|
||||
}
|
||||
|
||||
|
||||
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
|
||||
|
||||
ToolResponseUnion = (
|
||||
AgentsFoundResponse
|
||||
| NoResultsResponse
|
||||
| AgentDetailsResponse
|
||||
| SetupRequirementsResponse
|
||||
| ExecutionStartedResponse
|
||||
| NeedLoginResponse
|
||||
| ErrorResponse
|
||||
| InputValidationErrorResponse
|
||||
| AgentOutputResponse
|
||||
| UnderstandingUpdatedResponse
|
||||
| AgentPreviewResponse
|
||||
| AgentSavedResponse
|
||||
| ClarificationNeededResponse
|
||||
| BlockListResponse
|
||||
| BlockOutputResponse
|
||||
| DocSearchResultsResponse
|
||||
| DocPageResponse
|
||||
| OperationStartedResponse
|
||||
| OperationPendingResponse
|
||||
| OperationInProgressResponse
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/schema/tool-responses",
|
||||
response_model=ToolResponseUnion,
|
||||
include_in_schema=True,
|
||||
summary="[Dummy] Tool response type export for codegen",
|
||||
description="This endpoint is not meant to be called. It exists solely to "
|
||||
"expose tool response models in the OpenAPI schema for frontend codegen.",
|
||||
)
|
||||
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
|
||||
"""Never called at runtime. Exists only so Orval generates TS types."""
|
||||
raise HTTPException(status_code=501, detail="Schema-only endpoint")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,967 @@
|
||||
"""Stream registry for managing reconnectable SSE streams.
|
||||
|
||||
This module provides a registry for tracking active streaming tasks and their
|
||||
messages. It uses Redis for all state management (no in-memory state), making
|
||||
pods stateless and horizontally scalable.
|
||||
|
||||
Architecture:
|
||||
- Redis Stream: Persists all messages for replay and real-time delivery
|
||||
- Redis Hash: Task metadata (status, session_id, etc.)
|
||||
|
||||
Subscribers:
|
||||
1. Replay missed messages from Redis Stream (XREAD)
|
||||
2. Listen for live updates via blocking XREAD
|
||||
3. No in-memory state required on the subscribing pod
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from .config import ChatConfig
|
||||
from .response_model import StreamBaseResponse, StreamError, StreamFinish
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
# Track listener tasks per subscriber queue for cleanup
|
||||
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
|
||||
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
|
||||
|
||||
# Timeout for putting chunks into subscriber queues (seconds)
|
||||
# If the queue is full and doesn't drain within this time, send an overflow error
|
||||
QUEUE_PUT_TIMEOUT = 5.0
|
||||
|
||||
# Lua script for atomic compare-and-swap status update (idempotent completion)
|
||||
# Returns 1 if status was updated, 0 if already completed/failed
|
||||
COMPLETE_TASK_SCRIPT = """
|
||||
local current = redis.call("HGET", KEYS[1], "status")
|
||||
if current == "running" then
|
||||
redis.call("HSET", KEYS[1], "status", ARGV[1])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTask:
|
||||
"""Represents an active streaming task (metadata only, no in-memory queues)."""
|
||||
|
||||
task_id: str
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
operation_id: str
|
||||
status: Literal["running", "completed", "failed"] = "running"
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
asyncio_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
def _get_task_meta_key(task_id: str) -> str:
|
||||
"""Get Redis key for task metadata."""
|
||||
return f"{config.task_meta_prefix}{task_id}"
|
||||
|
||||
|
||||
def _get_task_stream_key(task_id: str) -> str:
|
||||
"""Get Redis key for task message stream."""
|
||||
return f"{config.task_stream_prefix}{task_id}"
|
||||
|
||||
|
||||
def _get_operation_mapping_key(operation_id: str) -> str:
|
||||
"""Get Redis key for operation_id to task_id mapping."""
|
||||
return f"{config.task_op_prefix}{operation_id}"
|
||||
|
||||
|
||||
async def create_task(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
operation_id: str,
|
||||
) -> ActiveTask:
|
||||
"""Create a new streaming task in Redis.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for the task
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (may be None for anonymous)
|
||||
tool_call_id: Tool call ID from the LLM
|
||||
tool_name: Name of the tool being executed
|
||||
operation_id: Operation ID for webhook callbacks
|
||||
|
||||
Returns:
|
||||
The created ActiveTask instance (metadata only)
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Build log metadata for structured logging
|
||||
log_meta = {
|
||||
"component": "StreamRegistry",
|
||||
"task_id": task_id,
|
||||
"session_id": session_id,
|
||||
}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
|
||||
task = ActiveTask(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Store metadata in Redis
|
||||
redis_start = time.perf_counter()
|
||||
redis = await get_redis_async()
|
||||
redis_time = (time.perf_counter() - redis_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
|
||||
)
|
||||
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
|
||||
hset_start = time.perf_counter()
|
||||
await redis.hset( # type: ignore[misc]
|
||||
meta_key,
|
||||
mapping={
|
||||
"task_id": task_id,
|
||||
"session_id": session_id,
|
||||
"user_id": user_id or "",
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"operation_id": operation_id,
|
||||
"status": task.status,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
hset_time = (time.perf_counter() - hset_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] redis.hset took {hset_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
|
||||
)
|
||||
|
||||
await redis.expire(meta_key, config.stream_ttl)
|
||||
|
||||
# Create operation_id -> task_id mapping for webhook lookups
|
||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
|
||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
async def publish_chunk(
|
||||
task_id: str,
|
||||
chunk: StreamBaseResponse,
|
||||
) -> str:
|
||||
"""Publish a chunk to Redis Stream.
|
||||
|
||||
All delivery is via Redis Streams - no in-memory state.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to publish to
|
||||
chunk: The stream response chunk to publish
|
||||
|
||||
Returns:
|
||||
The Redis Stream message ID
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
chunk_type = type(chunk).__name__
|
||||
chunk_json = chunk.model_dump_json()
|
||||
message_id = "0-0"
|
||||
|
||||
# Build log metadata
|
||||
log_meta = {
|
||||
"component": "StreamRegistry",
|
||||
"task_id": task_id,
|
||||
"chunk_type": chunk_type,
|
||||
}
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Write to Redis Stream for persistence and real-time delivery
|
||||
xadd_start = time.perf_counter()
|
||||
raw_id = await redis.xadd(
|
||||
stream_key,
|
||||
{"data": chunk_json},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
xadd_time = (time.perf_counter() - xadd_start) * 1000
|
||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||
|
||||
# Set TTL on stream to match task metadata TTL
|
||||
await redis.expire(stream_key, config.stream_ttl)
|
||||
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
# Only log timing for significant chunks or slow operations
|
||||
if (
|
||||
chunk_type
|
||||
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
|
||||
or total_time > 50
|
||||
):
|
||||
logger.info(
|
||||
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"xadd_time_ms": xadd_time,
|
||||
"message_id": message_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.error(
|
||||
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
|
||||
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return message_id
|
||||
|
||||
|
||||
async def subscribe_to_task(
|
||||
task_id: str,
|
||||
user_id: str | None,
|
||||
last_message_id: str = "0-0",
|
||||
) -> asyncio.Queue[StreamBaseResponse] | None:
|
||||
"""Subscribe to a task's stream with replay of missed messages.
|
||||
|
||||
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to subscribe to
|
||||
user_id: User ID for ownership validation
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
|
||||
|
||||
Returns:
|
||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||
or user doesn't have access
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Build log metadata
|
||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
|
||||
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
|
||||
)
|
||||
|
||||
redis_start = time.perf_counter()
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
||||
)
|
||||
|
||||
if not meta:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"reason": "task_not_found",
|
||||
}
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
task_status = meta.get("status", "")
|
||||
task_user_id = meta.get("user_id", "") or None
|
||||
log_meta["session_id"] = meta.get("session_id", "")
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task_user_id:
|
||||
if user_id != task_user_id:
|
||||
logger.warning(
|
||||
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"task_owner": task_user_id,
|
||||
"reason": "access_denied",
|
||||
}
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Step 1: Replay messages from Redis Stream
|
||||
xread_start = time.perf_counter()
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": xread_time,
|
||||
"task_status": task_status,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
replayed_count = 0
|
||||
replay_last_id = last_message_id
|
||||
if messages:
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
if "data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
replayed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"n_messages_replayed": replayed_count,
|
||||
"replay_last_id": replay_last_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: If task is still running, start stream listener for live updates
|
||||
if task_status == "running":
|
||||
logger.info(
|
||||
"[TIMING] Task still running, starting _stream_listener",
|
||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||
)
|
||||
listener_task = asyncio.create_task(
|
||||
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
|
||||
)
|
||||
# Track listener task for cleanup on unsubscribe
|
||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||
else:
|
||||
# Task is completed/failed - add finish marker
|
||||
logger.info(
|
||||
f"[TIMING] Task already {task_status}, adding StreamFinish",
|
||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||
)
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
|
||||
f"n_messages_replayed={replayed_count}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"n_messages_replayed": replayed_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
return subscriber_queue
|
||||
|
||||
|
||||
async def _stream_listener(
|
||||
task_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
last_replayed_id: str,
|
||||
log_meta: dict | None = None,
|
||||
) -> None:
|
||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||
|
||||
This approach avoids the duplicate message issue that can occur with pub/sub
|
||||
when messages are published during the gap between replay and subscription.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to listen for
|
||||
subscriber_queue: Queue to deliver messages to
|
||||
last_replayed_id: Last message ID from replay (continue from here)
|
||||
log_meta: Structured logging metadata
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Use provided log_meta or build minimal one
|
||||
if log_meta is None:
|
||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
|
||||
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
|
||||
)
|
||||
|
||||
queue_id = id(subscriber_queue)
|
||||
# Track the last successfully delivered message ID for recovery hints
|
||||
last_delivered_id = last_replayed_id
|
||||
messages_delivered = 0
|
||||
first_message_time = None
|
||||
xread_count = 0
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
current_id = last_replayed_id
|
||||
|
||||
while True:
|
||||
# Block for up to 30 seconds waiting for new messages
|
||||
# This allows periodic checking if task is still running
|
||||
xread_start = time.perf_counter()
|
||||
xread_count += 1
|
||||
messages = await redis.xread(
|
||||
{stream_key: current_id}, block=30000, count=100
|
||||
)
|
||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||
|
||||
if messages:
|
||||
msg_count = sum(len(msgs) for _, msgs in messages)
|
||||
logger.info(
|
||||
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"xread_count": xread_count,
|
||||
"n_messages": msg_count,
|
||||
"duration_ms": xread_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
elif xread_time > 1000:
|
||||
# Only log timeouts (30s blocking)
|
||||
logger.info(
|
||||
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"xread_count": xread_count,
|
||||
"duration_ms": xread_time,
|
||||
"reason": "timeout",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if not messages:
|
||||
# Timeout - check if task is still running
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
status = await redis.hget(meta_key, "status") # type: ignore[misc]
|
||||
if status and status != "running":
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamFinish()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Timeout delivering finish event for task {task_id}"
|
||||
)
|
||||
break
|
||||
continue
|
||||
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
|
||||
if "data" not in msg_data:
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(chunk),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
# Update last delivered ID on successful delivery
|
||||
last_delivered_id = current_id
|
||||
messages_delivered += 1
|
||||
if first_message_time is None:
|
||||
first_message_time = time.perf_counter()
|
||||
elapsed = (first_message_time - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
}
|
||||
},
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"timeout_s": QUEUE_PUT_TIMEOUT,
|
||||
"reason": "queue_full",
|
||||
}
|
||||
},
|
||||
)
|
||||
# Send overflow error with recovery info
|
||||
try:
|
||||
overflow_error = StreamError(
|
||||
errorText="Message delivery timeout - some messages may have been missed",
|
||||
code="QUEUE_OVERFLOW",
|
||||
details={
|
||||
"last_delivered_id": last_delivered_id,
|
||||
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
|
||||
},
|
||||
)
|
||||
subscriber_queue.put_nowait(overflow_error)
|
||||
except asyncio.QueueFull:
|
||||
# Queue is completely stuck, nothing more we can do
|
||||
logger.error(
|
||||
f"Cannot deliver overflow error for task {task_id}, "
|
||||
"queue completely blocked"
|
||||
)
|
||||
|
||||
# Stop listening on finish
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"messages_delivered": messages_delivered,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error processing stream message: {e}",
|
||||
extra={"json_fields": {**log_meta, "error": str(e)}},
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"messages_delivered": messages_delivered,
|
||||
"reason": "cancelled",
|
||||
}
|
||||
},
|
||||
)
|
||||
raise # Re-raise to propagate cancellation
|
||||
except Exception as e:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.error(
|
||||
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
|
||||
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||
)
|
||||
# On error, send finish to unblock subscriber
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamFinish()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||
logger.warning(
|
||||
"Could not deliver finish event after error",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
finally:
|
||||
# Clean up listener task mapping on exit
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
|
||||
f"delivered={messages_delivered}, xread_count={xread_count}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"messages_delivered": messages_delivered,
|
||||
"xread_count": xread_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
_listener_tasks.pop(queue_id, None)
|
||||
|
||||
|
||||
async def mark_task_completed(
|
||||
task_id: str,
|
||||
status: Literal["completed", "failed"] = "completed",
|
||||
) -> bool:
|
||||
"""Mark a task as completed and publish finish event.
|
||||
|
||||
This is idempotent - calling multiple times with the same task_id is safe.
|
||||
Uses atomic compare-and-swap via Lua script to prevent race conditions.
|
||||
Status is updated first (source of truth), then finish event is published (best-effort).
|
||||
|
||||
Args:
|
||||
task_id: Task ID to mark as completed
|
||||
status: Final status ("completed" or "failed")
|
||||
|
||||
Returns:
|
||||
True if task was newly marked completed, False if already completed/failed
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
|
||||
# Atomic compare-and-swap: only update if status is "running"
|
||||
# This prevents race conditions when multiple callers try to complete simultaneously
|
||||
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
|
||||
|
||||
if result == 0:
|
||||
logger.debug(f"Task {task_id} already completed/failed, skipping")
|
||||
return False
|
||||
|
||||
# THEN publish finish event (best-effort - listeners can detect via status polling)
|
||||
try:
|
||||
await publish_chunk(task_id, StreamFinish())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish finish event for task {task_id}: {e}. "
|
||||
"Listeners will detect completion via status polling."
|
||||
)
|
||||
|
||||
# Clean up local task reference if exists
|
||||
_local_tasks.pop(task_id, None)
|
||||
return True
|
||||
|
||||
|
||||
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
|
||||
"""Find a task by its operation ID.
|
||||
|
||||
Used by webhook callbacks to locate the task to update.
|
||||
|
||||
Args:
|
||||
operation_id: Operation ID to search for
|
||||
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
task_id = await redis.get(op_key)
|
||||
|
||||
if not task_id:
|
||||
return None
|
||||
|
||||
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
|
||||
return await get_task(task_id_str)
|
||||
|
||||
|
||||
async def get_task(task_id: str) -> ActiveTask | None:
|
||||
"""Get a task by its ID from Redis.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to look up
|
||||
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
return None
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||
return ActiveTask(
|
||||
task_id=meta.get("task_id", ""),
|
||||
session_id=meta.get("session_id", ""),
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
tool_call_id=meta.get("tool_call_id", ""),
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
operation_id=meta.get("operation_id", ""),
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
async def get_task_with_expiry_info(
|
||||
task_id: str,
|
||||
) -> tuple[ActiveTask | None, str | None]:
|
||||
"""Get a task by its ID with expiration detection.
|
||||
|
||||
Returns (task, error_code) where error_code is:
|
||||
- None if task found
|
||||
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
|
||||
- "TASK_NOT_FOUND" if neither exists
|
||||
|
||||
Args:
|
||||
task_id: Task ID to look up
|
||||
|
||||
Returns:
|
||||
Tuple of (ActiveTask or None, error_code or None)
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
# Check if stream still has data (metadata expired but stream hasn't)
|
||||
stream_len = await redis.xlen(stream_key)
|
||||
if stream_len > 0:
|
||||
return None, "TASK_EXPIRED"
|
||||
return None, "TASK_NOT_FOUND"
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||
return (
|
||||
ActiveTask(
|
||||
task_id=meta.get("task_id", ""),
|
||||
session_id=meta.get("session_id", ""),
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
tool_call_id=meta.get("tool_call_id", ""),
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
operation_id=meta.get("operation_id", ""),
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
async def get_active_task_for_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[ActiveTask | None, str]:
|
||||
"""Get the active (running) task for a session, if any.
|
||||
|
||||
Scans Redis for tasks matching the session_id with status="running".
|
||||
|
||||
Args:
|
||||
session_id: Session ID to look up
|
||||
user_id: User ID for ownership validation (optional)
|
||||
|
||||
Returns:
|
||||
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
|
||||
"""
|
||||
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Scan Redis for task metadata keys
|
||||
cursor = 0
|
||||
tasks_checked = 0
|
||||
|
||||
while True:
|
||||
cursor, keys = await redis.scan(
|
||||
cursor, match=f"{config.task_meta_prefix}*", count=100
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
tasks_checked += 1
|
||||
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
|
||||
if not meta:
|
||||
continue
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||
task_session_id = meta.get("session_id", "")
|
||||
task_status = meta.get("status", "")
|
||||
task_user_id = meta.get("user_id", "") or None
|
||||
task_id = meta.get("task_id", "")
|
||||
|
||||
if task_session_id == session_id and task_status == "running":
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task_user_id and user_id != task_user_id:
|
||||
continue
|
||||
|
||||
# Get the last message ID from Redis Stream
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
last_id = "0-0"
|
||||
try:
|
||||
messages = await redis.xrevrange(stream_key, count=1)
|
||||
if messages:
|
||||
msg_id = messages[0][0]
|
||||
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get last message ID: {e}")
|
||||
|
||||
return (
|
||||
ActiveTask(
|
||||
task_id=task_id,
|
||||
session_id=task_session_id,
|
||||
user_id=task_user_id,
|
||||
tool_call_id=meta.get("tool_call_id", ""),
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
operation_id=meta.get("operation_id", ""),
|
||||
status="running",
|
||||
),
|
||||
last_id,
|
||||
)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
return None, "0-0"
|
||||
|
||||
|
||||
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
"""Reconstruct a StreamBaseResponse from JSON data.
|
||||
|
||||
Args:
|
||||
chunk_data: Parsed JSON data from Redis
|
||||
|
||||
Returns:
|
||||
Reconstructed response object, or None if unknown type
|
||||
"""
|
||||
from .response_model import (
|
||||
ResponseType,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
# Map response types to their corresponding classes
|
||||
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
||||
ResponseType.START.value: StreamStart,
|
||||
ResponseType.FINISH.value: StreamFinish,
|
||||
ResponseType.START_STEP.value: StreamStartStep,
|
||||
ResponseType.FINISH_STEP.value: StreamFinishStep,
|
||||
ResponseType.TEXT_START.value: StreamTextStart,
|
||||
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
||||
ResponseType.TEXT_END.value: StreamTextEnd,
|
||||
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
|
||||
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
|
||||
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
|
||||
ResponseType.ERROR.value: StreamError,
|
||||
ResponseType.USAGE.value: StreamUsage,
|
||||
ResponseType.HEARTBEAT.value: StreamHeartbeat,
|
||||
}
|
||||
|
||||
chunk_type = chunk_data.get("type")
|
||||
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
|
||||
|
||||
if chunk_class is None:
|
||||
logger.warning(f"Unknown chunk type: {chunk_type}")
|
||||
return None
|
||||
|
||||
try:
|
||||
return chunk_class(**chunk_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
|
||||
"""Track the asyncio.Task for a task (local reference only).
|
||||
|
||||
This is just for cleanup purposes - the task state is in Redis.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
asyncio_task: The asyncio Task to track
|
||||
"""
|
||||
_local_tasks[task_id] = asyncio_task
|
||||
|
||||
|
||||
async def unsubscribe_from_task(
|
||||
task_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
) -> None:
|
||||
"""Clean up when a subscriber disconnects.
|
||||
|
||||
Cancels the XREAD-based listener task associated with this subscriber queue
|
||||
to prevent resource leaks.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
subscriber_queue: The subscriber's queue used to look up the listener task
|
||||
"""
|
||||
queue_id = id(subscriber_queue)
|
||||
listener_entry = _listener_tasks.pop(queue_id, None)
|
||||
|
||||
if listener_entry is None:
|
||||
logger.debug(
|
||||
f"No listener task found for task {task_id} queue {queue_id} "
|
||||
"(may have already completed)"
|
||||
)
|
||||
return
|
||||
|
||||
stored_task_id, listener_task = listener_entry
|
||||
|
||||
if stored_task_id != task_id:
|
||||
logger.warning(
|
||||
f"Task ID mismatch in unsubscribe: expected {task_id}, "
|
||||
f"found {stored_task_id}"
|
||||
)
|
||||
|
||||
if listener_task.done():
|
||||
logger.debug(f"Listener task for task {task_id} already completed")
|
||||
return
|
||||
|
||||
# Cancel the listener task
|
||||
listener_task.cancel()
|
||||
|
||||
try:
|
||||
# Wait for the task to be cancelled with a timeout
|
||||
await asyncio.wait_for(listener_task, timeout=5.0)
|
||||
except asyncio.CancelledError:
|
||||
# Expected - the task was successfully cancelled
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Timeout waiting for listener task cancellation for task {task_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
|
||||
|
||||
logger.debug(f"Successfully unsubscribed from task {task_id}")
|
||||
@@ -10,6 +10,7 @@ from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
@@ -34,6 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"add_understanding": AddUnderstandingTool(),
|
||||
"create_agent": CreateAgentTool(),
|
||||
"customize_agent": CustomizeAgentTool(),
|
||||
"edit_agent": EditAgentTool(),
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_block": FindBlockTool(),
|
||||
|
||||
@@ -8,6 +8,7 @@ from .core import (
|
||||
DecompositionStep,
|
||||
LibraryAgentSummary,
|
||||
MarketplaceAgentSummary,
|
||||
customize_template,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
extract_search_terms_from_steps,
|
||||
@@ -19,6 +20,7 @@ from .core import (
|
||||
get_library_agent_by_graph_id,
|
||||
get_library_agent_by_id,
|
||||
get_library_agents_for_generation,
|
||||
graph_to_json,
|
||||
json_to_graph,
|
||||
save_agent_to_library,
|
||||
search_marketplace_agents_for_generation,
|
||||
@@ -36,6 +38,7 @@ __all__ = [
|
||||
"LibraryAgentSummary",
|
||||
"MarketplaceAgentSummary",
|
||||
"check_external_service_health",
|
||||
"customize_template",
|
||||
"decompose_goal",
|
||||
"enrich_library_agents_from_steps",
|
||||
"extract_search_terms_from_steps",
|
||||
@@ -48,6 +51,7 @@ __all__ = [
|
||||
"get_library_agent_by_id",
|
||||
"get_library_agents_for_generation",
|
||||
"get_user_message_for_error",
|
||||
"graph_to_json",
|
||||
"is_external_service_configured",
|
||||
"json_to_graph",
|
||||
"save_agent_to_library",
|
||||
|
||||
@@ -7,18 +7,11 @@ from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import (
|
||||
Graph,
|
||||
Link,
|
||||
Node,
|
||||
create_graph,
|
||||
get_graph,
|
||||
get_graph_all_versions,
|
||||
get_store_listed_graphs,
|
||||
)
|
||||
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .service import (
|
||||
customize_template_external,
|
||||
decompose_goal_external,
|
||||
generate_agent_external,
|
||||
generate_agent_patch_external,
|
||||
@@ -27,8 +20,6 @@ from .service import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||
|
||||
|
||||
class ExecutionSummary(TypedDict):
|
||||
"""Summary of a single execution for quality assessment."""
|
||||
@@ -549,15 +540,21 @@ async def decompose_goal(
|
||||
async def generate_agent(
|
||||
instructions: DecompositionResult | dict[str, Any],
|
||||
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams
|
||||
completion notification)
|
||||
task_id: Task ID for async processing (enables Redis Streams persistence
|
||||
and SSE delivery)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
@@ -565,8 +562,13 @@ async def generate_agent(
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(
|
||||
dict(instructions), _to_dict_list(library_agents)
|
||||
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
|
||||
)
|
||||
|
||||
# Don't modify async response
|
||||
if result and result.get("status") == "accepted":
|
||||
return result
|
||||
|
||||
if result:
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
return result
|
||||
@@ -657,45 +659,6 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
)
|
||||
|
||||
|
||||
def _reassign_node_ids(graph: Graph) -> None:
|
||||
"""Reassign all node and link IDs to new UUIDs.
|
||||
|
||||
This is needed when creating a new version to avoid unique constraint violations.
|
||||
"""
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
for link in graph.links:
|
||||
link.id = str(uuid.uuid4())
|
||||
if link.source_id in id_map:
|
||||
link.source_id = id_map[link.source_id]
|
||||
if link.sink_id in id_map:
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
|
||||
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
|
||||
"""Populate user_id in AgentExecutorBlock nodes.
|
||||
|
||||
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
|
||||
This function fills in the actual user_id so sub-agents run with correct permissions.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON dict (modified in place)
|
||||
user_id: User ID to set
|
||||
"""
|
||||
for node in agent_json.get("nodes", []):
|
||||
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
|
||||
input_default = node.get("input_default") or {}
|
||||
if not input_default.get("user_id"):
|
||||
input_default["user_id"] = user_id
|
||||
node["input_default"] = input_default
|
||||
logger.debug(
|
||||
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
|
||||
)
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
) -> tuple[Graph, Any]:
|
||||
@@ -709,63 +672,21 @@ async def save_agent_to_library(
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
# Populate user_id in AgentExecutorBlock nodes before conversion
|
||||
_populate_agent_executor_user_ids(agent_json, user_id)
|
||||
|
||||
graph = json_to_graph(agent_json)
|
||||
|
||||
if is_update:
|
||||
if graph.id:
|
||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||
if existing_versions:
|
||||
latest_version = max(v.version for v in existing_versions)
|
||||
graph.version = latest_version + 1
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||
else:
|
||||
graph.id = str(uuid.uuid4())
|
||||
graph.version = 1
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Creating new agent with ID {graph.id}")
|
||||
|
||||
created_graph = await create_graph(graph, user_id)
|
||||
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
)
|
||||
|
||||
return created_graph, library_agents[0]
|
||||
return await library_db.update_graph_in_library(graph, user_id)
|
||||
return await library_db.create_graph_in_library(graph, user_id)
|
||||
|
||||
|
||||
async def get_agent_as_json(
|
||||
agent_id: str, user_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch an agent and convert to JSON format for editing.
|
||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
"""Convert a Graph object to JSON format for the agent generator.
|
||||
|
||||
Args:
|
||||
agent_id: Graph ID or library agent ID
|
||||
user_id: User ID
|
||||
graph: Graph object to convert
|
||||
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
Agent as JSON dict
|
||||
"""
|
||||
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||
|
||||
if not graph and user_id:
|
||||
try:
|
||||
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=None, user_id=user_id
|
||||
)
|
||||
except NotFoundError:
|
||||
pass
|
||||
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
nodes.append(
|
||||
@@ -802,10 +723,41 @@ async def get_agent_as_json(
|
||||
}
|
||||
|
||||
|
||||
async def get_agent_as_json(
|
||||
agent_id: str, user_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch an agent and convert to JSON format for editing.
|
||||
|
||||
Args:
|
||||
agent_id: Graph ID or library agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||
|
||||
if not graph and user_id:
|
||||
try:
|
||||
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=None, user_id=user_id
|
||||
)
|
||||
except NotFoundError:
|
||||
pass
|
||||
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
return graph_to_json(graph)
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
@@ -818,10 +770,12 @@ async def generate_agent_patch(
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on unexpected error
|
||||
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
@@ -829,5 +783,43 @@ async def generate_agent_patch(
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(
|
||||
update_request, current_agent, _to_dict_list(library_agents)
|
||||
update_request,
|
||||
current_agent,
|
||||
_to_dict_list(library_agents),
|
||||
operation_id,
|
||||
task_id,
|
||||
)
|
||||
|
||||
|
||||
async def customize_template(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Customize a template/marketplace agent using natural language.
|
||||
|
||||
This is used when users want to modify a template or marketplace agent
|
||||
to fit their specific needs before adding it to their library.
|
||||
|
||||
The external Agent Generator service handles:
|
||||
- Understanding the modification request
|
||||
- Applying changes to the template
|
||||
- Fixing and validating the result
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on unexpected error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for customize_template")
|
||||
return await customize_template_external(
|
||||
template_agent, modification_request, context
|
||||
)
|
||||
|
||||
@@ -212,24 +212,45 @@ async def decompose_goal_external(
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict on success, or error dict {"type": "error", ...} on error
|
||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -261,6 +282,8 @@ async def generate_agent_patch_external(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
@@ -268,21 +291,40 @@ async def generate_agent_patch_external(
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, or error dict on error
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async update request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -326,6 +368,77 @@ async def generate_agent_patch_external(
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def customize_template_external(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to customize a template/marketplace agent.
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
request = modification_request
|
||||
if context:
|
||||
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"template_agent_json": template_agent,
|
||||
"modification_request": request,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await client.post("/api/template-modification", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator template customization failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the customized agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
"""Get available blocks from the external service.
|
||||
|
||||
|
||||
@@ -206,9 +206,9 @@ async def search_agents(
|
||||
]
|
||||
)
|
||||
no_results_msg = (
|
||||
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
||||
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
|
||||
if source == "marketplace"
|
||||
else f"No agents matching '{query}' found in your library."
|
||||
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||
@@ -224,10 +224,10 @@ async def search_agents(
|
||||
message = (
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents."
|
||||
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
|
||||
if source == "marketplace"
|
||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
|
||||
@@ -18,6 +18,7 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -98,6 +99,10 @@ class CreateAgentTool(BaseTool):
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
@@ -219,7 +224,12 @@ class CreateAgentTool(BaseTool):
|
||||
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
||||
|
||||
try:
|
||||
agent_json = await generate_agent(decomposition_result, library_agents)
|
||||
agent_json = await generate_agent(
|
||||
decomposition_result,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
@@ -263,6 +273,19 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if agent_json.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent generation delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent generation started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
|
||||
@@ -0,0 +1,337 @@
|
||||
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
customize_template,
|
||||
get_user_message_for_error,
|
||||
graph_to_json,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomizeAgentTool(BaseTool):
|
||||
"""Tool for customizing marketplace/template agents using natural language."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "customize_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Customize a marketplace or template agent using natural language. "
|
||||
"Takes an existing agent from the marketplace and modifies it based on "
|
||||
"the user's requirements before adding to their library."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The marketplace agent ID in format 'creator/slug' "
|
||||
"(e.g., 'autogpt/newsletter-writer'). "
|
||||
"Get this from find_agent results."
|
||||
),
|
||||
},
|
||||
"modifications": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of how to customize the agent. "
|
||||
"Be specific about what changes you want to make."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the customized agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "modifications"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the customize_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Parse the agent ID to get creator/slug
|
||||
2. Fetch the template agent from the marketplace
|
||||
3. Call customize_template with the modification request
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
modifications = kwargs.get("modifications", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||
error="missing_agent_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not modifications:
|
||||
return ErrorResponse(
|
||||
message="Please describe how you want to customize this agent.",
|
||||
error="missing_modifications",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Parse agent_id in format "creator/slug"
|
||||
parts = [p.strip() for p in agent_id.split("/")]
|
||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Invalid agent ID format: '{agent_id}'. "
|
||||
"Expected format is 'creator/agent-name' "
|
||||
"(e.g., 'autogpt/newsletter-writer')."
|
||||
),
|
||||
error="invalid_agent_id_format",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
creator_username, agent_slug = parts
|
||||
|
||||
# Fetch the marketplace agent details
|
||||
try:
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
username=creator_username, agent_name=agent_slug
|
||||
)
|
||||
except AgentNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Could not find marketplace agent '{agent_id}'. "
|
||||
"Please check the agent ID and try again."
|
||||
),
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the marketplace agent. Please try again.",
|
||||
error="fetch_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agent_details.store_listing_version_id:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The agent '{agent_id}' does not have an available version. "
|
||||
"Please try a different agent."
|
||||
),
|
||||
error="no_version_available",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get the full agent graph
|
||||
try:
|
||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||
template_agent = graph_to_json(graph)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the agent configuration. Please try again.",
|
||||
error="graph_fetch_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Call customize_template
|
||||
try:
|
||||
result = await customize_template(
|
||||
template_agent=template_agent,
|
||||
modification_request=modifications,
|
||||
context=context,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent customization is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent due to a service error. "
|
||||
"Please try again."
|
||||
),
|
||||
error="customization_service_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent. "
|
||||
"The agent generation service may be unavailable or timed out. "
|
||||
"Please try again."
|
||||
),
|
||||
error="customization_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle error response
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="customize the agent",
|
||||
llm_parse_message=(
|
||||
"The AI had trouble customizing the agent. "
|
||||
"Please try again or simplify your request."
|
||||
),
|
||||
validation_message=(
|
||||
"The customized agent failed validation. "
|
||||
"Please try rephrasing your request."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"customization_failed:{error_type}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle clarifying questions
|
||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions") or []
|
||||
if not isinstance(questions, list):
|
||||
logger.error(
|
||||
f"Unexpected clarifying questions format: {type(questions)}"
|
||||
)
|
||||
questions = []
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to customize this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
if isinstance(q, dict)
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result should be the customized agent JSON
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||
return ErrorResponse(
|
||||
message="Failed to customize the agent due to an unexpected response.",
|
||||
error="unexpected_response_type",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
customized_agent = result
|
||||
|
||||
agent_name = customized_agent.get(
|
||||
"name", f"Customized {agent_details.agent_name}"
|
||||
)
|
||||
agent_description = customized_agent.get("description", "")
|
||||
nodes = customized_agent.get("nodes")
|
||||
links = customized_agent.get("links")
|
||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||
link_count = len(links) if isinstance(links, list) else 0
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||
f"The customized agent has {node_count} blocks. "
|
||||
f"Review it and call customize_agent with save=true to save it."
|
||||
),
|
||||
agent_json=customized_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to user's library
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
customized_agent, user_id, is_update=False
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Customized agent '{created_graph.name}' "
|
||||
f"(based on '{agent_details.agent_name}') "
|
||||
f"has been saved to your library!"
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving customized agent: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to save the customized agent. Please try again.",
|
||||
error="save_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -17,6 +17,7 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -104,6 +105,10 @@ class EditAgentTool(BaseTool):
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
@@ -149,7 +154,11 @@ class EditAgentTool(BaseTool):
|
||||
|
||||
try:
|
||||
result = await generate_agent_patch(
|
||||
update_request, current_agent, library_agents
|
||||
update_request,
|
||||
current_agent,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
@@ -169,6 +178,20 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if result.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent edit delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent edit started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
|
||||
@@ -13,10 +13,32 @@ from backend.api.features.chat.tools.models import (
|
||||
NoResultsResponse,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.data.block import get_block
|
||||
from backend.data.block import BlockType, get_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TARGET_RESULTS = 10
|
||||
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
|
||||
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
|
||||
_OVERFETCH_PAGE_SIZE = 40
|
||||
|
||||
# Block types that only work within graphs and cannot run standalone in CoPilot.
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
|
||||
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
|
||||
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
|
||||
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
|
||||
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
||||
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
||||
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
||||
}
|
||||
|
||||
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||
COPILOT_EXCLUDED_BLOCK_IDS = {
|
||||
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
}
|
||||
|
||||
|
||||
class FindBlockTool(BaseTool):
|
||||
"""Tool for searching available blocks."""
|
||||
@@ -88,7 +110,7 @@ class FindBlockTool(BaseTool):
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=10,
|
||||
page_size=_OVERFETCH_PAGE_SIZE,
|
||||
)
|
||||
|
||||
if not results:
|
||||
@@ -108,60 +130,90 @@ class FindBlockTool(BaseTool):
|
||||
block = get_block(block_id)
|
||||
|
||||
# Skip disabled blocks
|
||||
if block and not block.disabled:
|
||||
# Get input/output schemas
|
||||
input_schema = {}
|
||||
output_schema = {}
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
if not block or block.disabled:
|
||||
continue
|
||||
|
||||
# Get categories from block instance
|
||||
categories = []
|
||||
if hasattr(block, "categories") and block.categories:
|
||||
categories = [cat.value for cat in block.categories]
|
||||
# Skip blocks excluded from CoPilot (graph-only blocks)
|
||||
if (
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
continue
|
||||
|
||||
# Extract required inputs for easier use
|
||||
required_inputs: list[BlockInputFieldInfo] = []
|
||||
if input_schema:
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = set(input_schema.get("required", []))
|
||||
# Get credential field names to exclude from required inputs
|
||||
credentials_fields = set(
|
||||
block.input_schema.get_credentials_fields().keys()
|
||||
)
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields - they're handled separately
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
required_inputs.append(
|
||||
BlockInputFieldInfo(
|
||||
name=field_name,
|
||||
type=field_schema.get("type", "string"),
|
||||
description=field_schema.get("description", ""),
|
||||
required=field_name in required_fields,
|
||||
default=field_schema.get("default"),
|
||||
)
|
||||
)
|
||||
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=categories,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
required_inputs=required_inputs,
|
||||
)
|
||||
# Get input/output schemas
|
||||
input_schema = {}
|
||||
output_schema = {}
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to generate input schema for block %s: %s",
|
||||
block_id,
|
||||
e,
|
||||
)
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to generate output schema for block %s: %s",
|
||||
block_id,
|
||||
e,
|
||||
)
|
||||
|
||||
# Get categories from block instance
|
||||
categories = []
|
||||
if hasattr(block, "categories") and block.categories:
|
||||
categories = [cat.value for cat in block.categories]
|
||||
|
||||
# Extract required inputs for easier use
|
||||
required_inputs: list[BlockInputFieldInfo] = []
|
||||
if input_schema:
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = set(input_schema.get("required", []))
|
||||
# Get credential field names to exclude from required inputs
|
||||
credentials_fields = set(
|
||||
block.input_schema.get_credentials_fields().keys()
|
||||
)
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields - they're handled separately
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
required_inputs.append(
|
||||
BlockInputFieldInfo(
|
||||
name=field_name,
|
||||
type=field_schema.get("type", "string"),
|
||||
description=field_schema.get("description", ""),
|
||||
required=field_name in required_fields,
|
||||
default=field_schema.get("default"),
|
||||
)
|
||||
)
|
||||
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=categories,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
required_inputs=required_inputs,
|
||||
)
|
||||
)
|
||||
|
||||
if len(blocks) >= _TARGET_RESULTS:
|
||||
break
|
||||
|
||||
if blocks and len(blocks) < _TARGET_RESULTS:
|
||||
logger.debug(
|
||||
"find_block returned %d/%d results for query '%s' "
|
||||
"(filtered %d excluded/disabled blocks)",
|
||||
len(blocks),
|
||||
_TARGET_RESULTS,
|
||||
query,
|
||||
len(results) - len(blocks),
|
||||
)
|
||||
|
||||
if not blocks:
|
||||
return NoResultsResponse(
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
"""Tests for block filtering in FindBlockTool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
FindBlockTool,
|
||||
)
|
||||
from backend.api.features.chat.tools.models import BlockListResponse
|
||||
from backend.data.block import BlockType
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-find-block"
|
||||
|
||||
|
||||
def make_mock_block(
|
||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||
):
|
||||
"""Create a mock block for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.description = f"{name} description"
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||
mock.input_schema.get_credentials_fields.return_value = {}
|
||||
mock.output_schema = MagicMock()
|
||||
mock.output_schema.jsonschema.return_value = {}
|
||||
mock.categories = []
|
||||
return mock
|
||||
|
||||
|
||||
class TestFindBlockFiltering:
|
||||
"""Tests for block filtering in FindBlockTool."""
|
||||
|
||||
def test_excluded_block_types_contains_expected_types(self):
|
||||
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
||||
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
|
||||
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
||||
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_type_filtered_from_results(self):
|
||||
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
||||
search_results = [
|
||||
{"content_id": "input-block-id", "score": 0.9},
|
||||
{"content_id": "standard-block-id", "score": 0.8},
|
||||
]
|
||||
|
||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||
standard_block = make_mock_block(
|
||||
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
||||
)
|
||||
|
||||
def mock_get_block(block_id):
|
||||
return {
|
||||
"input-block-id": input_block,
|
||||
"standard-block-id": standard_block,
|
||||
}.get(block_id)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(search_results, 2),
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.get_block",
|
||||
side_effect=mock_get_block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query="test"
|
||||
)
|
||||
|
||||
# Should only return the standard block, not the INPUT block
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert len(response.blocks) == 1
|
||||
assert response.blocks[0].id == "standard-block-id"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_id_filtered_from_results(self):
|
||||
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
search_results = [
|
||||
{"content_id": smart_decision_id, "score": 0.9},
|
||||
{"content_id": "normal-block-id", "score": 0.8},
|
||||
]
|
||||
|
||||
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
||||
smart_block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
normal_block = make_mock_block(
|
||||
"normal-block-id", "Normal Block", BlockType.STANDARD
|
||||
)
|
||||
|
||||
def mock_get_block(block_id):
|
||||
return {
|
||||
smart_decision_id: smart_block,
|
||||
"normal-block-id": normal_block,
|
||||
}.get(block_id)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(search_results, 2),
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.get_block",
|
||||
side_effect=mock_get_block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query="decision"
|
||||
)
|
||||
|
||||
# Should only return normal block, not SmartDecisionMakerBlock
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert len(response.blocks) == 1
|
||||
assert response.blocks[0].id == "normal-block-id"
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Shared helpers for chat tools."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
exclude_fields: set[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Extract input field info from JSON schema."""
|
||||
if not isinstance(input_schema, dict):
|
||||
return []
|
||||
|
||||
exclude = exclude_fields or set()
|
||||
properties = input_schema.get("properties", {})
|
||||
required = set(input_schema.get("required", []))
|
||||
|
||||
return [
|
||||
{
|
||||
"name": name,
|
||||
"title": schema.get("title", name),
|
||||
"type": schema.get("type", "string"),
|
||||
"description": schema.get("description", ""),
|
||||
"required": name in required,
|
||||
"default": schema.get("default"),
|
||||
}
|
||||
for name, schema in properties.items()
|
||||
if name not in exclude
|
||||
]
|
||||
@@ -38,6 +38,8 @@ class ResponseType(str, Enum):
|
||||
OPERATION_STARTED = "operation_started"
|
||||
OPERATION_PENDING = "operation_pending"
|
||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||
# Input validation
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -68,6 +70,10 @@ class AgentInfo(BaseModel):
|
||||
has_external_trigger: bool | None = None
|
||||
new_output: bool | None = None
|
||||
graph_id: str | None = None
|
||||
inputs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Input schema for the agent, including field names, types, and defaults",
|
||||
)
|
||||
|
||||
|
||||
class AgentsFoundResponse(ToolResponseBase):
|
||||
@@ -194,6 +200,20 @@ class ErrorResponse(ToolResponseBase):
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class InputValidationErrorResponse(ToolResponseBase):
|
||||
"""Response when run_agent receives unknown input fields."""
|
||||
|
||||
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
|
||||
unrecognized_fields: list[str] = Field(
|
||||
description="List of input field names that were not recognized"
|
||||
)
|
||||
inputs: dict[str, Any] = Field(
|
||||
description="The agent's valid input schema for reference"
|
||||
)
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
|
||||
|
||||
# Agent output models
|
||||
class ExecutionOutputInfo(BaseModel):
|
||||
"""Summary of a single execution's outputs."""
|
||||
@@ -352,11 +372,15 @@ class OperationStartedResponse(ToolResponseBase):
|
||||
|
||||
This is returned immediately to the client while the operation continues
|
||||
to execute. The user can close the tab and check back later.
|
||||
|
||||
The task_id can be used to reconnect to the SSE stream via
|
||||
GET /chat/tasks/{task_id}/stream?last_idx=0
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
task_id: str | None = None # For SSE reconnection
|
||||
|
||||
|
||||
class OperationPendingResponse(ToolResponseBase):
|
||||
@@ -380,3 +404,20 @@ class OperationInProgressResponse(ToolResponseBase):
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AsyncProcessingResponse(ToolResponseBase):
|
||||
"""Response when an operation has been delegated to async processing.
|
||||
|
||||
This is returned by tools when the external service accepts the request
|
||||
for async processing (HTTP 202 Accepted). The Redis Streams completion
|
||||
consumer will handle the result when the external service completes.
|
||||
|
||||
The status field is specifically "accepted" to allow the long-running tool
|
||||
handler to detect this response and skip LLM continuation.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
status: str = "accepted" # Must be "accepted" for detection
|
||||
operation_id: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
@@ -24,12 +24,14 @@ from backend.util.timezone_utils import (
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .models import (
|
||||
AgentDetails,
|
||||
AgentDetailsResponse,
|
||||
ErrorResponse,
|
||||
ExecutionOptions,
|
||||
ExecutionStartedResponse,
|
||||
InputValidationErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
@@ -260,7 +262,7 @@ class RunAgentTool(BaseTool):
|
||||
),
|
||||
requirements={
|
||||
"credentials": requirements_creds_list,
|
||||
"inputs": self._get_inputs_list(graph.input_schema),
|
||||
"inputs": get_inputs_from_schema(graph.input_schema),
|
||||
"execution_modes": self._get_execution_modes(graph),
|
||||
},
|
||||
),
|
||||
@@ -273,6 +275,22 @@ class RunAgentTool(BaseTool):
|
||||
input_properties = graph.input_schema.get("properties", {})
|
||||
required_fields = set(graph.input_schema.get("required", []))
|
||||
provided_inputs = set(params.inputs.keys())
|
||||
valid_fields = set(input_properties.keys())
|
||||
|
||||
# Check for unknown input fields
|
||||
unrecognized_fields = provided_inputs - valid_fields
|
||||
if unrecognized_fields:
|
||||
return InputValidationErrorResponse(
|
||||
message=(
|
||||
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
|
||||
f"Agent was not executed. Please use the correct field names from the schema."
|
||||
),
|
||||
session_id=session_id,
|
||||
unrecognized_fields=sorted(unrecognized_fields),
|
||||
inputs=graph.input_schema,
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
)
|
||||
|
||||
# If agent has inputs but none were provided AND use_defaults is not set,
|
||||
# always show what's available first so user can decide
|
||||
@@ -352,22 +370,6 @@ class RunAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract inputs list from schema."""
|
||||
inputs_list = []
|
||||
if isinstance(input_schema, dict) and "properties" in input_schema:
|
||||
for field_name, field_schema in input_schema["properties"].items():
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in input_schema.get("required", []),
|
||||
}
|
||||
)
|
||||
return inputs_list
|
||||
|
||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||
"""Get available execution modes for the graph."""
|
||||
trigger_info = graph.trigger_setup_info
|
||||
@@ -381,7 +383,7 @@ class RunAgentTool(BaseTool):
|
||||
suffix: str,
|
||||
) -> str:
|
||||
"""Build a message describing available inputs for an agent."""
|
||||
inputs_list = self._get_inputs_list(graph.input_schema)
|
||||
inputs_list = get_inputs_from_schema(graph.input_schema)
|
||||
required_names = [i["name"] for i in inputs_list if i["required"]]
|
||||
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
||||
|
||||
|
||||
@@ -402,3 +402,42 @@ async def test_run_agent_schedule_without_name(setup_test_data):
|
||||
# Should return error about missing schedule_name
|
||||
assert result_data.get("type") == "error"
|
||||
assert "schedule_name" in result_data["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
|
||||
"""Test that run_agent returns input_validation_error for unknown input fields."""
|
||||
user = setup_test_data["user"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
tool = RunAgentTool()
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Execute with unknown input field names
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={
|
||||
"unknown_field": "some value",
|
||||
"another_unknown": "another value",
|
||||
},
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return input_validation_error type with unrecognized fields
|
||||
assert result_data.get("type") == "input_validation_error"
|
||||
assert "unrecognized_fields" in result_data
|
||||
assert set(result_data["unrecognized_fields"]) == {
|
||||
"another_unknown",
|
||||
"unknown_field",
|
||||
}
|
||||
assert "inputs" in result_data # Contains the valid schema
|
||||
assert "Agent was not executed" in result_data["message"]
|
||||
|
||||
@@ -5,15 +5,22 @@ import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.api.features.chat.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
)
|
||||
from backend.data.block import AnyBlockSchema, get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
@@ -22,7 +29,10 @@ from .models import (
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from .utils import build_missing_credentials_from_field_info
|
||||
from .utils import (
|
||||
build_missing_credentials_from_field_info,
|
||||
match_credentials_to_requirements,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,65 +81,6 @@ class RunBlockTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _check_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
missing_credentials: list[CredentialsMetaInput] = []
|
||||
|
||||
# Get credential field info from block's input schema
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
if not credentials_fields_info:
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
# Get user's available credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
# field_info.provider is a frozenset of acceptable providers
|
||||
# field_info.supported_types is a frozenset of acceptable types
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in field_info.provider
|
||||
and cred.type in field_info.supported_types
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if matching_cred:
|
||||
matched_credentials[field_name] = CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
else:
|
||||
# Create a placeholder for the missing credential
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing_credentials.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
@@ -184,12 +135,24 @@ class RunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if block is excluded from CoPilot (graph-only blocks)
|
||||
if (
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
||||
"This block is designed for use within graphs only."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
# Check credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||
user_id, block
|
||||
matched_credentials, missing_credentials = (
|
||||
await self._resolve_block_credentials(user_id, block, input_data)
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
@@ -318,29 +281,75 @@ class RunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||
async def _resolve_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Resolve credentials for a block by matching user's available credentials.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
block: Block to resolve credentials for
|
||||
input_data: Input data for the block (used to determine provider via discriminator)
|
||||
|
||||
Returns:
|
||||
tuple of (matched_credentials, missing_credentials) - matched credentials
|
||||
are used for block execution, missing ones indicate setup requirements.
|
||||
"""
|
||||
input_data = input_data or {}
|
||||
requirements = self._resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
if not requirements:
|
||||
return {}, []
|
||||
|
||||
return await match_credentials_to_requirements(user_id, requirements)
|
||||
|
||||
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
inputs_list = []
|
||||
schema = block.input_schema.jsonschema()
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
# Get credential field names to exclude
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
def _resolve_discriminated_credentials(
|
||||
self,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in required_fields,
|
||||
}
|
||||
)
|
||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||
|
||||
return inputs_list
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(
|
||||
field_info.discriminator
|
||||
)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
# For host-scoped credentials, add the discriminator value
|
||||
# (e.g., URL) so _credential_is_for_host can match it
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
return resolved
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Tests for block execution guards in RunBlockTool."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.models import ErrorResponse
|
||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||
from backend.data.block import BlockType
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-run-block"
|
||||
|
||||
|
||||
def make_mock_block(
|
||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||
):
|
||||
"""Create a mock block for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||
mock.input_schema.get_credentials_fields_info.return_value = []
|
||||
return mock
|
||||
|
||||
|
||||
class TestRunBlockFiltering:
|
||||
"""Tests for block execution guards in RunBlockTool."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_type_returns_error(self):
|
||||
"""Attempting to execute a block with excluded BlockType returns error."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=input_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="input-block-id",
|
||||
input_data={},
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "cannot be run directly in CoPilot" in response.message
|
||||
assert "designed for use within graphs only" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_id_returns_error(self):
|
||||
"""Attempting to execute SmartDecisionMakerBlock returns error."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
smart_block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=smart_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=smart_decision_id,
|
||||
input_data={},
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "cannot be run directly in CoPilot" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_excluded_block_passes_guard(self):
|
||||
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
standard_block = make_mock_block(
|
||||
"standard-id", "HTTP Request", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="standard-id",
|
||||
input_data={},
|
||||
)
|
||||
|
||||
# Should NOT be an ErrorResponse about CoPilot exclusion
|
||||
# (may be other errors like missing credentials, but not the exclusion guard)
|
||||
if isinstance(response, ErrorResponse):
|
||||
assert "cannot be run directly in CoPilot" not in response.message
|
||||
@@ -6,9 +6,14 @@ from typing import Any
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -39,14 +44,8 @@ async def fetch_graph_from_store_slug(
|
||||
return None, None
|
||||
|
||||
# Get the graph from store listing version
|
||||
graph_meta = await store_db.get_available_graph(
|
||||
store_agent.store_listing_version_id
|
||||
)
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=graph_meta.id,
|
||||
version=graph_meta.version,
|
||||
user_id=None, # Public access
|
||||
include_subgraphs=True,
|
||||
graph = await store_db.get_available_graph(
|
||||
store_agent.store_listing_version_id, hide_nodes=False
|
||||
)
|
||||
return graph, store_agent
|
||||
|
||||
@@ -123,7 +122,7 @@ def build_missing_credentials_from_graph(
|
||||
|
||||
return {
|
||||
field_key: _serialize_missing_credential(field_key, field_info)
|
||||
for field_key, (field_info, _node_fields) in aggregated_fields.items()
|
||||
for field_key, (field_info, _, _) in aggregated_fields.items()
|
||||
if field_key not in matched_keys
|
||||
}
|
||||
|
||||
@@ -225,6 +224,99 @@ async def get_or_create_library_agent(
|
||||
return library_agents[0]
|
||||
|
||||
|
||||
async def match_credentials_to_requirements(
|
||||
user_id: str,
|
||||
requirements: dict[str, CredentialsFieldInfo],
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Match user's credentials against a dictionary of credential requirements.
|
||||
|
||||
This is the core matching logic shared by both graph and block credential matching.
|
||||
"""
|
||||
matched: dict[str, CredentialsMetaInput] = {}
|
||||
missing: list[CredentialsMetaInput] = []
|
||||
|
||||
if not requirements:
|
||||
return matched, missing
|
||||
|
||||
available_creds = await get_user_credentials(user_id)
|
||||
|
||||
for field_name, field_info in requirements.items():
|
||||
matching_cred = find_matching_credential(available_creds, field_info)
|
||||
|
||||
if matching_cred:
|
||||
try:
|
||||
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
||||
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
||||
f"credential_id={matching_cred.id}",
|
||||
exc_info=True,
|
||||
)
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=f"{field_name} (validation failed: {e})",
|
||||
)
|
||||
)
|
||||
else:
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched, missing
|
||||
|
||||
|
||||
async def get_user_credentials(user_id: str) -> list[Credentials]:
|
||||
"""Get all available credentials for a user."""
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
return await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
|
||||
def find_matching_credential(
|
||||
available_creds: list[Credentials],
|
||||
field_info: CredentialsFieldInfo,
|
||||
) -> Credentials | None:
|
||||
"""Find a credential that matches the required provider, type, scopes, and host."""
|
||||
for cred in available_creds:
|
||||
if cred.provider not in field_info.provider:
|
||||
continue
|
||||
if cred.type not in field_info.supported_types:
|
||||
continue
|
||||
if cred.type == "oauth2" and not _credential_has_required_scopes(
|
||||
cred, field_info
|
||||
):
|
||||
continue
|
||||
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
|
||||
continue
|
||||
return cred
|
||||
return None
|
||||
|
||||
|
||||
def create_credential_meta_from_match(
|
||||
matching_cred: Credentials,
|
||||
) -> CredentialsMetaInput:
|
||||
"""Create a CredentialsMetaInput from a matched credential."""
|
||||
return CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
|
||||
|
||||
async def match_user_credentials_to_graph(
|
||||
user_id: str,
|
||||
graph: GraphModel,
|
||||
@@ -264,7 +356,8 @@ async def match_user_credentials_to_graph(
|
||||
# provider is in the set of acceptable providers.
|
||||
for credential_field_name, (
|
||||
credential_requirements,
|
||||
_node_fields,
|
||||
_,
|
||||
_,
|
||||
) in aggregated_creds.items():
|
||||
# Find first matching credential by provider, type, and scopes
|
||||
matching_cred = next(
|
||||
@@ -273,7 +366,14 @@ async def match_user_credentials_to_graph(
|
||||
for cred in available_creds
|
||||
if cred.provider in credential_requirements.provider
|
||||
and cred.type in credential_requirements.supported_types
|
||||
and _credential_has_required_scopes(cred, credential_requirements)
|
||||
and (
|
||||
cred.type != "oauth2"
|
||||
or _credential_has_required_scopes(cred, credential_requirements)
|
||||
)
|
||||
and (
|
||||
cred.type != "host_scoped"
|
||||
or _credential_is_for_host(cred, credential_requirements)
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -318,27 +418,32 @@ async def match_user_credentials_to_graph(
|
||||
|
||||
|
||||
def _credential_has_required_scopes(
|
||||
credential: Credentials,
|
||||
credential: OAuth2Credentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a credential has all the scopes required by the block.
|
||||
|
||||
For OAuth2 credentials, verifies that the credential's scopes are a superset
|
||||
of the required scopes. For other credential types, returns True (no scope check).
|
||||
"""
|
||||
# Only OAuth2 credentials have scopes to check
|
||||
if credential.type != "oauth2":
|
||||
return True
|
||||
|
||||
"""Check if an OAuth2 credential has all the scopes required by the input."""
|
||||
# If no scopes are required, any credential matches
|
||||
if not requirements.required_scopes:
|
||||
return True
|
||||
|
||||
# Check that credential scopes are a superset of required scopes
|
||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||
|
||||
|
||||
def _credential_is_for_host(
|
||||
credential: HostScopedCredentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""Check if a host-scoped credential matches the host required by the input."""
|
||||
# We need to know the host to match host-scoped credentials to.
|
||||
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
|
||||
# to discriminator_values. No discriminator_values -> no host to match against.
|
||||
if not requirements.discriminator_values:
|
||||
return True
|
||||
|
||||
# Check that credential host matches required host.
|
||||
# Host-scoped credential inputs are grouped by host, so any item from the set works.
|
||||
return credential.matches_url(list(requirements.discriminator_values)[0])
|
||||
|
||||
|
||||
async def check_user_has_required_credentials(
|
||||
user_id: str,
|
||||
required_credentials: list[CredentialsMetaInput],
|
||||
|
||||
@@ -19,7 +19,10 @@ from backend.data.graph import GraphSettings
|
||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
on_graph_activate,
|
||||
on_graph_deactivate,
|
||||
)
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
@@ -371,7 +374,7 @@ async def get_library_agent_by_graph_id(
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
graph: graph_db.BaseGraph,
|
||||
graph: graph_db.GraphBaseMeta,
|
||||
user_id: str,
|
||||
library_agent_id: str,
|
||||
) -> Optional[prisma.models.LibraryAgent]:
|
||||
@@ -537,6 +540,92 @@ async def update_agent_version_in_library(
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
|
||||
|
||||
async def create_graph_in_library(
|
||||
graph: graph_db.Graph,
|
||||
user_id: str,
|
||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||
"""Create a new graph and add it to the user's library."""
|
||||
graph.version = 1
|
||||
graph_model = graph_db.make_graph_model(graph, user_id)
|
||||
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
|
||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||
|
||||
library_agents = await create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
)
|
||||
|
||||
if created_graph.is_active:
|
||||
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
||||
|
||||
return created_graph, library_agents[0]
|
||||
|
||||
|
||||
async def update_graph_in_library(
|
||||
graph: graph_db.Graph,
|
||||
user_id: str,
|
||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||
"""Create a new version of an existing graph and update the library entry."""
|
||||
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
|
||||
current_active_version = (
|
||||
next((v for v in existing_versions if v.is_active), None)
|
||||
if existing_versions
|
||||
else None
|
||||
)
|
||||
graph.version = (
|
||||
max(v.version for v in existing_versions) + 1 if existing_versions else 1
|
||||
)
|
||||
|
||||
graph_model = graph_db.make_graph_model(graph, user_id)
|
||||
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||
|
||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||
|
||||
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
||||
if not library_agent:
|
||||
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
||||
|
||||
library_agent = await update_library_agent_version_and_settings(
|
||||
user_id, created_graph
|
||||
)
|
||||
|
||||
if created_graph.is_active:
|
||||
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=created_graph.id,
|
||||
version=created_graph.version,
|
||||
user_id=user_id,
|
||||
)
|
||||
if current_active_version:
|
||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||
|
||||
return created_graph, library_agent
|
||||
|
||||
|
||||
async def update_library_agent_version_and_settings(
|
||||
user_id: str, agent_graph: graph_db.GraphModel
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Update library agent to point to new graph version and sync settings."""
|
||||
library = await update_agent_version_in_library(
|
||||
user_id, agent_graph.id, agent_graph.version
|
||||
)
|
||||
updated_settings = GraphSettings.from_graph(
|
||||
graph=agent_graph,
|
||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||
)
|
||||
if updated_settings != library.settings:
|
||||
library = await update_library_agent(
|
||||
library_agent_id=library.id,
|
||||
user_id=user_id,
|
||||
settings=updated_settings,
|
||||
)
|
||||
return library
|
||||
|
||||
|
||||
async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
import fastapi
|
||||
import prisma.enums
|
||||
@@ -11,8 +11,8 @@ import prisma.types
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
GraphModelWithoutNodes,
|
||||
get_graph,
|
||||
get_graph_as_admin,
|
||||
get_sub_graphs,
|
||||
@@ -334,7 +334,22 @@ async def get_store_agent_details(
|
||||
raise DatabaseError("Failed to fetch agent details") from e
|
||||
|
||||
|
||||
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||
@overload
|
||||
async def get_available_graph(
|
||||
store_listing_version_id: str, hide_nodes: Literal[False]
|
||||
) -> GraphModel: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def get_available_graph(
|
||||
store_listing_version_id: str, hide_nodes: Literal[True] = True
|
||||
) -> GraphModelWithoutNodes: ...
|
||||
|
||||
|
||||
async def get_available_graph(
|
||||
store_listing_version_id: str,
|
||||
hide_nodes: bool = True,
|
||||
) -> GraphModelWithoutNodes | GraphModel:
|
||||
try:
|
||||
# Get avaialble, non-deleted store listing version
|
||||
store_listing_version = (
|
||||
@@ -344,7 +359,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||
"isAvailable": True,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={"AgentGraph": {"include": {"Nodes": True}}},
|
||||
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -354,7 +369,9 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
|
||||
return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
|
||||
store_listing_version.AgentGraph
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent: {e}")
|
||||
|
||||
@@ -454,6 +454,9 @@ async def test_unified_hybrid_search_pagination(
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test unified search pagination works correctly."""
|
||||
# Use a unique search term to avoid matching other test data
|
||||
unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create multiple items
|
||||
content_ids = []
|
||||
for i in range(5):
|
||||
@@ -465,14 +468,14 @@ async def test_unified_hybrid_search_pagination(
|
||||
content_type=ContentType.BLOCK,
|
||||
content_id=content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text=f"pagination test item number {i}",
|
||||
searchable_text=f"{unique_term} item number {i}",
|
||||
metadata={"index": i},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Get first page
|
||||
page1_results, total1 = await unified_hybrid_search(
|
||||
query="pagination test",
|
||||
query=unique_term,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=2,
|
||||
@@ -480,7 +483,7 @@ async def test_unified_hybrid_search_pagination(
|
||||
|
||||
# Get second page
|
||||
page2_results, total2 = await unified_hybrid_search(
|
||||
query="pagination test",
|
||||
query=unique_term,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=2,
|
||||
page_size=2,
|
||||
|
||||
@@ -8,6 +8,7 @@ Includes BM25 reranking for improved lexical relevance.
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
@@ -362,7 +363,11 @@ async def unified_hybrid_search(
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(sql_query, *params)
|
||||
try:
|
||||
results = await query_raw_with_schema(sql_query, *params)
|
||||
except Exception as e:
|
||||
await _log_vector_error_diagnostics(e)
|
||||
raise
|
||||
|
||||
total = results[0]["total_count"] if results else 0
|
||||
# Apply BM25 reranking
|
||||
@@ -686,7 +691,11 @@ async def hybrid_search(
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(sql_query, *params)
|
||||
try:
|
||||
results = await query_raw_with_schema(sql_query, *params)
|
||||
except Exception as e:
|
||||
await _log_vector_error_diagnostics(e)
|
||||
raise
|
||||
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
@@ -718,6 +727,87 @@ async def hybrid_search_simple(
|
||||
return await hybrid_search(query=query, page=page, page_size=page_size)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Diagnostics
|
||||
# ============================================================================
|
||||
|
||||
# Rate limit: only log vector error diagnostics once per this interval
|
||||
_VECTOR_DIAG_INTERVAL_SECONDS = 60
|
||||
_last_vector_diag_time: float = 0
|
||||
|
||||
|
||||
async def _log_vector_error_diagnostics(error: Exception) -> None:
|
||||
"""Log diagnostic info when 'type vector does not exist' error occurs.
|
||||
|
||||
Note: Diagnostic queries use query_raw_with_schema which may run on a different
|
||||
pooled connection than the one that failed. Session-level search_path can differ,
|
||||
so these diagnostics show cluster-wide state, not necessarily the failed session.
|
||||
|
||||
Includes rate limiting to avoid log spam - only logs once per minute.
|
||||
Caller should re-raise the error after calling this function.
|
||||
"""
|
||||
global _last_vector_diag_time
|
||||
|
||||
# Check if this is the vector type error
|
||||
error_str = str(error).lower()
|
||||
if not (
|
||||
"type" in error_str and "vector" in error_str and "does not exist" in error_str
|
||||
):
|
||||
return
|
||||
|
||||
# Rate limit: only log once per interval
|
||||
now = time.time()
|
||||
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
|
||||
return
|
||||
_last_vector_diag_time = now
|
||||
|
||||
try:
|
||||
diagnostics: dict[str, object] = {}
|
||||
|
||||
try:
|
||||
search_path_result = await query_raw_with_schema("SHOW search_path")
|
||||
diagnostics["search_path"] = search_path_result
|
||||
except Exception as e:
|
||||
diagnostics["search_path"] = f"Error: {e}"
|
||||
|
||||
try:
|
||||
schema_result = await query_raw_with_schema("SELECT current_schema()")
|
||||
diagnostics["current_schema"] = schema_result
|
||||
except Exception as e:
|
||||
diagnostics["current_schema"] = f"Error: {e}"
|
||||
|
||||
try:
|
||||
user_result = await query_raw_with_schema(
|
||||
"SELECT current_user, session_user, current_database()"
|
||||
)
|
||||
diagnostics["user_info"] = user_result
|
||||
except Exception as e:
|
||||
diagnostics["user_info"] = f"Error: {e}"
|
||||
|
||||
try:
|
||||
# Check pgvector extension installation (cluster-wide, stable info)
|
||||
ext_result = await query_raw_with_schema(
|
||||
"SELECT extname, extversion, nspname as schema "
|
||||
"FROM pg_extension e "
|
||||
"JOIN pg_namespace n ON e.extnamespace = n.oid "
|
||||
"WHERE extname = 'vector'"
|
||||
)
|
||||
diagnostics["pgvector_extension"] = ext_result
|
||||
except Exception as e:
|
||||
diagnostics["pgvector_extension"] = f"Error: {e}"
|
||||
|
||||
logger.error(
|
||||
f"Vector type error diagnostics:\n"
|
||||
f" Error: {error}\n"
|
||||
f" search_path: {diagnostics.get('search_path')}\n"
|
||||
f" current_schema: {diagnostics.get('current_schema')}\n"
|
||||
f" user_info: {diagnostics.get('user_info')}\n"
|
||||
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
|
||||
)
|
||||
except Exception as diag_error:
|
||||
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
|
||||
|
||||
|
||||
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
||||
# for existing code that expects the popularity parameter
|
||||
HybridSearchWeights = StoreAgentSearchWeights
|
||||
|
||||
@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
|
||||
StyleType,
|
||||
UpscaleOption,
|
||||
)
|
||||
from backend.data.graph import BaseGraph
|
||||
from backend.data.graph import GraphBaseMeta
|
||||
from backend.data.model import CredentialsMetaInput, ProviderName
|
||||
from backend.integrations.credentials_store import ideogram_credentials
|
||||
from backend.util.request import Requests
|
||||
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
|
||||
DIGITAL_ART = "digital art"
|
||||
|
||||
|
||||
async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||
async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
||||
if settings.config.use_agent_image_generation_v2:
|
||||
return await generate_agent_image_v2(graph=agent)
|
||||
else:
|
||||
return await generate_agent_image_v1(agent=agent)
|
||||
|
||||
|
||||
async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||
async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
||||
"""
|
||||
Generate an image for an agent using Ideogram model.
|
||||
Returns:
|
||||
@@ -54,14 +54,17 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||
description = f"{name} ({graph.description})" if graph.description else name
|
||||
|
||||
prompt = (
|
||||
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
|
||||
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
|
||||
f"along with recognizable objects directly associated with the primary function of a {name}. "
|
||||
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
|
||||
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
|
||||
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
|
||||
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
|
||||
f"prioritizing clear visual storytelling and thematic clarity above all else."
|
||||
"Create a visually striking retro-futuristic vector pop art illustration "
|
||||
f'prominently featuring "{name}" in bold typography. The image clearly and '
|
||||
f"literally depicts a {description}, along with recognizable objects directly "
|
||||
f"associated with the primary function of a {name}. "
|
||||
f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
|
||||
f"clearly conveying the purpose of a {name}. "
|
||||
"Maintain vibrant, limited-palette colors, sharp vector lines, "
|
||||
"geometric shapes, flat illustration techniques, and solid colors "
|
||||
"without gradients or shading. Preserve a retro-futuristic aesthetic "
|
||||
"influenced by mid-century futurism and 1960s psychedelia, "
|
||||
"prioritizing clear visual storytelling and thematic clarity above all else."
|
||||
)
|
||||
|
||||
custom_colors = [
|
||||
@@ -99,12 +102,12 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||
return io.BytesIO(response.content)
|
||||
|
||||
|
||||
async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||
async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
||||
"""
|
||||
Generate an image for an agent using Flux model via Replicate API.
|
||||
|
||||
Args:
|
||||
agent (Graph): The agent to generate an image for
|
||||
agent (GraphBaseMeta | AgentGraph): The agent to generate an image for
|
||||
|
||||
Returns:
|
||||
io.BytesIO: The generated image as bytes
|
||||
@@ -114,7 +117,13 @@ async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||
raise ValueError("Missing Replicate API key in settings")
|
||||
|
||||
# Construct prompt from agent details
|
||||
prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
|
||||
prompt = (
|
||||
"Create a visually engaging app store thumbnail for the AI agent "
|
||||
"that highlights what it does in a clear and captivating way:\n"
|
||||
f"- **Name**: {agent.name}\n"
|
||||
f"- **Description**: {agent.description}\n"
|
||||
f"Focus on showcasing its core functionality with an appealing design."
|
||||
)
|
||||
|
||||
# Set up Replicate client
|
||||
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
||||
|
||||
@@ -278,7 +278,7 @@ async def get_agent(
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphMeta:
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""
|
||||
Get Agent Graph from Store Listing Version ID.
|
||||
"""
|
||||
|
||||
@@ -101,7 +101,6 @@ from backend.util.timezone_utils import (
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
from .library import db as library_db
|
||||
from .library import model as library_model
|
||||
from .store.model import StoreAgentDetails
|
||||
|
||||
|
||||
@@ -823,18 +822,16 @@ async def update_graph(
|
||||
graph: graph_db.Graph,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> graph_db.GraphModel:
|
||||
# Sanity check
|
||||
if graph.id and graph.id != graph_id:
|
||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||
|
||||
# Determine new version
|
||||
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||
if not existing_versions:
|
||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||
latest_version_number = max(g.version for g in existing_versions)
|
||||
graph.version = latest_version_number + 1
|
||||
|
||||
graph.version = max(g.version for g in existing_versions) + 1
|
||||
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
||||
|
||||
graph = graph_db.make_graph_model(graph, user_id)
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||
graph.validate_graph(for_run=False)
|
||||
@@ -842,27 +839,23 @@ async def update_graph(
|
||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
if new_graph_version.is_active:
|
||||
# Keep the library agent up to date with the new active version
|
||||
await _update_library_agent_version_and_settings(user_id, new_graph_version)
|
||||
|
||||
# Handle activation of the new graph first to ensure continuity
|
||||
await library_db.update_library_agent_version_and_settings(
|
||||
user_id, new_graph_version
|
||||
)
|
||||
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
||||
# Ensure new version is the only active version
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||
)
|
||||
if current_active_version:
|
||||
# Handle deactivation of the previously active version
|
||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||
|
||||
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
|
||||
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
||||
graph_id,
|
||||
new_graph_version.version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
assert new_graph_version_with_subgraphs # make type checker happy
|
||||
assert new_graph_version_with_subgraphs
|
||||
return new_graph_version_with_subgraphs
|
||||
|
||||
|
||||
@@ -900,33 +893,15 @@ async def set_graph_active_version(
|
||||
)
|
||||
|
||||
# Keep the library agent up to date with the new active version
|
||||
await _update_library_agent_version_and_settings(user_id, new_active_graph)
|
||||
await library_db.update_library_agent_version_and_settings(
|
||||
user_id, new_active_graph
|
||||
)
|
||||
|
||||
if current_active_graph and current_active_graph.version != new_active_version:
|
||||
# Handle deactivation of the previously active version
|
||||
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
||||
|
||||
|
||||
async def _update_library_agent_version_and_settings(
|
||||
user_id: str, agent_graph: graph_db.GraphModel
|
||||
) -> library_model.LibraryAgent:
|
||||
library = await library_db.update_agent_version_in_library(
|
||||
user_id, agent_graph.id, agent_graph.version
|
||||
)
|
||||
updated_settings = GraphSettings.from_graph(
|
||||
graph=agent_graph,
|
||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||
)
|
||||
if updated_settings != library.settings:
|
||||
library = await library_db.update_library_agent(
|
||||
library_agent_id=library.id,
|
||||
user_id=user_id,
|
||||
settings=updated_settings,
|
||||
)
|
||||
return library
|
||||
|
||||
|
||||
@v1_router.patch(
|
||||
path="/graphs/{graph_id}/settings",
|
||||
summary="Update graph settings",
|
||||
|
||||
@@ -40,6 +40,10 @@ import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.chat.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
# Start chat completion consumer for Redis Streams notifications
|
||||
try:
|
||||
await start_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
# Stop chat completion consumer
|
||||
try:
|
||||
await stop_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
|
||||
28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""ElevenLabs integration blocks - test credentials and shared utilities."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="elevenlabs",
|
||||
api_key=SecretStr("mock-elevenlabs-api-key"),
|
||||
title="Mock ElevenLabs API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
ElevenLabsCredentials = APIKeyCredentials
|
||||
ElevenLabsCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.ELEVENLABS], Literal["api_key"]
|
||||
]
|
||||
77
autogpt_platform/backend/backend/blocks/encoder_block.py
Normal file
77
autogpt_platform/backend/backend/blocks/encoder_block.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Text encoding block for converting special characters to escape sequences."""
|
||||
|
||||
import codecs
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TextEncoderBlock(Block):
|
||||
"""
|
||||
Encodes a string by converting special characters into escape sequences.
|
||||
|
||||
This block is the inverse of TextDecoderBlock. It takes text containing
|
||||
special characters (like newlines, tabs, etc.) and converts them into
|
||||
their escape sequence representations (e.g., newline becomes \\n).
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for TextEncoderBlock."""
|
||||
|
||||
text: str = SchemaField(
|
||||
description="A string containing special characters to be encoded",
|
||||
placeholder="Your text with newlines and quotes to encode",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for TextEncoderBlock."""
|
||||
|
||||
encoded_text: str = SchemaField(
|
||||
description="The encoded text with special characters converted to escape sequences"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if encoding fails")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
|
||||
description="Encodes a string by converting special characters into escape sequences",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=TextEncoderBlock.Input,
|
||||
output_schema=TextEncoderBlock.Output,
|
||||
test_input={
|
||||
"text": """Hello
|
||||
World!
|
||||
This is a "quoted" string."""
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"encoded_text",
|
||||
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Encode the input text by converting special characters to escape sequences.
|
||||
|
||||
Args:
|
||||
input_data: The input containing the text to encode.
|
||||
**kwargs: Additional keyword arguments (unused).
|
||||
|
||||
Yields:
|
||||
The encoded text with escape sequences, or an error message if encoding fails.
|
||||
"""
|
||||
try:
|
||||
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
|
||||
"utf-8"
|
||||
)
|
||||
yield "encoded_text", encoded_text
|
||||
except Exception as e:
|
||||
yield "error", f"Encoding error: {str(e)}"
|
||||
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
try:
|
||||
webset = aexa.websets.get(id=input_data.external_id)
|
||||
webset = await aexa.websets.get(id=input_data.external_id)
|
||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||
|
||||
yield "webset", webset_result
|
||||
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
||||
count=input_data.search_count,
|
||||
)
|
||||
|
||||
webset = aexa.websets.create(
|
||||
webset = await aexa.websets.create(
|
||||
params=CreateWebsetParameters(
|
||||
search=search_params,
|
||||
external_id=input_data.external_id,
|
||||
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
|
||||
if input_data.metadata is not None:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||
|
||||
status_str = (
|
||||
sdk_webset.status.value
|
||||
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = aexa.websets.list(
|
||||
response = await aexa.websets.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_webset = aexa.websets.get(id=input_data.webset_id)
|
||||
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
status_str = (
|
||||
sdk_webset.status.value
|
||||
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
|
||||
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
|
||||
|
||||
status_str = (
|
||||
deleted_webset.status.value
|
||||
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
|
||||
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
|
||||
|
||||
status_str = (
|
||||
canceled_webset.status.value
|
||||
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
|
||||
entity["description"] = input_data.entity_description
|
||||
payload["entity"] = entity
|
||||
|
||||
sdk_preview = aexa.websets.preview(params=payload)
|
||||
sdk_preview = await aexa.websets.preview(params=payload)
|
||||
|
||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||
|
||||
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
status = (
|
||||
webset.status.value
|
||||
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
# Extract basic info
|
||||
webset_id = webset.id
|
||||
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
|
||||
total_items = 0
|
||||
|
||||
if input_data.include_sample_items and input_data.sample_size > 0:
|
||||
items_response = aexa.websets.items.list(
|
||||
items_response = await aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||
)
|
||||
sample_items_data = [
|
||||
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get webset details
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
status = (
|
||||
webset.status.value
|
||||
|
||||
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_enrichment = aexa.websets.enrichments.create(
|
||||
sdk_enrichment = await aexa.websets.enrichments.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
items_enriched = 0
|
||||
|
||||
while time.time() - poll_start < input_data.polling_timeout:
|
||||
current_enrich = aexa.websets.enrichments.get(
|
||||
current_enrich = await aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=enrichment_id
|
||||
)
|
||||
current_status = (
|
||||
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
|
||||
if current_status in ["completed", "failed", "cancelled"]:
|
||||
# Estimate items from webset searches
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
if webset.searches:
|
||||
for search in webset.searches:
|
||||
if search.progress:
|
||||
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_enrichment = aexa.websets.enrichments.get(
|
||||
sdk_enrichment = await aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_enrichment = aexa.websets.enrichments.delete(
|
||||
deleted_enrichment = await aexa.websets.enrichments.delete(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_enrichment = aexa.websets.enrichments.cancel(
|
||||
canceled_enrichment = await aexa.websets.enrichments.cancel(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
# Try to estimate how many items were enriched before cancellation
|
||||
items_enriched = 0
|
||||
items_response = aexa.websets.items.list(
|
||||
items_response = await aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=100
|
||||
)
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
|
||||
def _create_test_mock():
|
||||
"""Create test mocks for the AsyncExa SDK."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
# Create mock SDK import object
|
||||
mock_import = MagicMock()
|
||||
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(
|
||||
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
|
||||
imports=MagicMock(create=AsyncMock(return_value=mock_import))
|
||||
)
|
||||
)
|
||||
}
|
||||
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_import = aexa.websets.imports.create(
|
||||
sdk_import = await aexa.websets.imports.create(
|
||||
params=payload, csv_data=input_data.csv_data
|
||||
)
|
||||
|
||||
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
|
||||
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
|
||||
|
||||
import_obj = ImportModel.from_sdk(sdk_import)
|
||||
|
||||
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = aexa.websets.imports.list(
|
||||
response = await aexa.websets.imports.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
@@ -474,7 +474,9 @@ class ExaDeleteImportBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
|
||||
deleted_import = await aexa.websets.imports.delete(
|
||||
import_id=input_data.import_id
|
||||
)
|
||||
|
||||
yield "import_id", deleted_import.id
|
||||
yield "success", "true"
|
||||
@@ -573,14 +575,14 @@ class ExaExportWebsetBlock(Block):
|
||||
}
|
||||
)
|
||||
|
||||
# Create mock iterator
|
||||
mock_items = [mock_item1, mock_item2]
|
||||
# Create async iterator for list_all
|
||||
async def async_item_iterator(*args, **kwargs):
|
||||
for item in [mock_item1, mock_item2]:
|
||||
yield item
|
||||
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(
|
||||
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
|
||||
)
|
||||
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
|
||||
)
|
||||
}
|
||||
|
||||
@@ -602,7 +604,7 @@ class ExaExportWebsetBlock(Block):
|
||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||
)
|
||||
|
||||
for sdk_item in item_iterator:
|
||||
async for sdk_item in item_iterator:
|
||||
if len(all_items) >= input_data.max_items:
|
||||
break
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_item = aexa.websets.items.get(
|
||||
sdk_item = await aexa.websets.items.get(
|
||||
webset_id=input_data.webset_id, id=input_data.item_id
|
||||
)
|
||||
|
||||
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
|
||||
response = None
|
||||
|
||||
while time.time() - start_time < input_data.wait_timeout:
|
||||
response = aexa.websets.items.list(
|
||||
response = await aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
|
||||
interval = min(interval * 1.2, 10)
|
||||
|
||||
if not response:
|
||||
response = aexa.websets.items.list(
|
||||
response = await aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
else:
|
||||
response = aexa.websets.items.list(
|
||||
response = await aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_item = aexa.websets.items.delete(
|
||||
deleted_item = await aexa.websets.items.delete(
|
||||
webset_id=input_data.webset_id, id=input_data.item_id
|
||||
)
|
||||
|
||||
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
|
||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||
)
|
||||
|
||||
for sdk_item in item_iterator:
|
||||
async for sdk_item in item_iterator:
|
||||
if len(all_items) >= input_data.max_items:
|
||||
break
|
||||
|
||||
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
entity_type = "unknown"
|
||||
if webset.searches:
|
||||
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
||||
# Get sample items if requested
|
||||
sample_items: List[WebsetItemModel] = []
|
||||
if input_data.sample_size > 0:
|
||||
items_response = aexa.websets.items.list(
|
||||
items_response = await aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||
)
|
||||
# Convert to our stable models
|
||||
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get items starting from cursor
|
||||
response = aexa.websets.items.list(
|
||||
response = await aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.since_cursor,
|
||||
limit=input_data.max_items,
|
||||
|
||||
@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
|
||||
def _create_test_mock():
|
||||
"""Create test mocks for the AsyncExa SDK."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
# Create mock SDK monitor object
|
||||
mock_monitor = MagicMock()
|
||||
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(
|
||||
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
|
||||
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
|
||||
)
|
||||
)
|
||||
}
|
||||
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_monitor = aexa.websets.monitors.create(params=payload)
|
||||
sdk_monitor = await aexa.websets.monitors.create(params=payload)
|
||||
|
||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||
|
||||
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||
|
||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||
|
||||
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
|
||||
if input_data.metadata is not None:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_monitor = aexa.websets.monitors.update(
|
||||
sdk_monitor = await aexa.websets.monitors.update(
|
||||
monitor_id=input_data.monitor_id, params=payload
|
||||
)
|
||||
|
||||
@@ -522,7 +522,9 @@ class ExaDeleteMonitorBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
|
||||
deleted_monitor = await aexa.websets.monitors.delete(
|
||||
monitor_id=input_data.monitor_id
|
||||
)
|
||||
|
||||
yield "monitor_id", deleted_monitor.id
|
||||
yield "success", "true"
|
||||
@@ -579,7 +581,7 @@ class ExaListMonitorsBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = aexa.websets.monitors.list(
|
||||
response = await aexa.websets.monitors.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
webset_id=input_data.webset_id,
|
||||
|
||||
@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
|
||||
WebsetTargetStatus.IDLE,
|
||||
WebsetTargetStatus.ANY_COMPLETE,
|
||||
]:
|
||||
final_webset = aexa.websets.wait_until_idle(
|
||||
final_webset = await aexa.websets.wait_until_idle(
|
||||
id=input_data.webset_id,
|
||||
timeout=input_data.timeout,
|
||||
poll_interval=input_data.check_interval,
|
||||
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
|
||||
interval = input_data.check_interval
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current webset status
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
current_status = (
|
||||
webset.status.value
|
||||
if hasattr(webset.status, "value")
|
||||
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
|
||||
|
||||
# Timeout reached
|
||||
elapsed = time.time() - start_time
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
final_status = (
|
||||
webset.status.value
|
||||
if hasattr(webset.status, "value")
|
||||
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
|
||||
try:
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current search status using SDK
|
||||
search = aexa.websets.searches.get(
|
||||
search = await aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Get last known status
|
||||
search = aexa.websets.searches.get(
|
||||
search = await aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
final_status = (
|
||||
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
||||
try:
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current enrichment status using SDK
|
||||
enrichment = aexa.websets.enrichments.get(
|
||||
enrichment = await aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Get last known status
|
||||
enrichment = aexa.websets.enrichments.get(
|
||||
enrichment = await aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
final_status = (
|
||||
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
||||
) -> tuple[list[SampleEnrichmentModel], int]:
|
||||
"""Get sample enriched data and count."""
|
||||
# Get a few items to see enrichment results using SDK
|
||||
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||
|
||||
sample_data: list[SampleEnrichmentModel] = []
|
||||
enriched_count = 0
|
||||
|
||||
@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
||||
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_search = aexa.websets.searches.create(
|
||||
sdk_search = await aexa.websets.searches.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
||||
poll_start = time.time()
|
||||
|
||||
while time.time() - poll_start < input_data.polling_timeout:
|
||||
current_search = aexa.websets.searches.get(
|
||||
current_search = await aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=search_id
|
||||
)
|
||||
current_status = (
|
||||
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_search = aexa.websets.searches.get(
|
||||
sdk_search = await aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_search = aexa.websets.searches.cancel(
|
||||
canceled_search = await aexa.websets.searches.cancel(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get webset to check existing searches
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
# Look for existing search with same query
|
||||
existing_search = None
|
||||
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
||||
if input_data.entity_type != SearchEntityType.AUTO:
|
||||
payload["entity"] = {"type": input_data.entity_type.value}
|
||||
|
||||
sdk_search = aexa.websets.searches.create(
|
||||
sdk_search = await aexa.websets.searches.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
|
||||
@@ -162,8 +162,16 @@ class LinearClient:
|
||||
"searchTerm": team_name,
|
||||
}
|
||||
|
||||
team_id = await self.query(query, variables)
|
||||
return team_id["teams"]["nodes"][0]["id"]
|
||||
result = await self.query(query, variables)
|
||||
nodes = result["teams"]["nodes"]
|
||||
|
||||
if not nodes:
|
||||
raise LinearAPIException(
|
||||
f"Team '{team_name}' not found. Check the team name or key and try again.",
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
return nodes[0]["id"]
|
||||
except LinearAPIException as e:
|
||||
raise e
|
||||
|
||||
@@ -240,17 +248,44 @@ class LinearClient:
|
||||
except LinearAPIException as e:
|
||||
raise e
|
||||
|
||||
async def try_search_issues(self, term: str) -> list[Issue]:
|
||||
async def try_search_issues(
|
||||
self,
|
||||
term: str,
|
||||
max_results: int = 10,
|
||||
team_id: str | None = None,
|
||||
) -> list[Issue]:
|
||||
try:
|
||||
query = """
|
||||
query SearchIssues($term: String!, $includeComments: Boolean!) {
|
||||
searchIssues(term: $term, includeComments: $includeComments) {
|
||||
query SearchIssues(
|
||||
$term: String!,
|
||||
$first: Int,
|
||||
$teamId: String
|
||||
) {
|
||||
searchIssues(
|
||||
term: $term,
|
||||
first: $first,
|
||||
teamId: $teamId
|
||||
) {
|
||||
nodes {
|
||||
id
|
||||
identifier
|
||||
title
|
||||
description
|
||||
priority
|
||||
createdAt
|
||||
state {
|
||||
id
|
||||
name
|
||||
type
|
||||
}
|
||||
project {
|
||||
id
|
||||
name
|
||||
}
|
||||
assignee {
|
||||
id
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -258,7 +293,8 @@ class LinearClient:
|
||||
|
||||
variables: dict[str, Any] = {
|
||||
"term": term,
|
||||
"includeComments": True,
|
||||
"first": max_results,
|
||||
"teamId": team_id,
|
||||
}
|
||||
|
||||
issues = await self.query(query, variables)
|
||||
|
||||
@@ -17,7 +17,7 @@ from ._config import (
|
||||
LinearScope,
|
||||
linear,
|
||||
)
|
||||
from .models import CreateIssueResponse, Issue
|
||||
from .models import CreateIssueResponse, Issue, State
|
||||
|
||||
|
||||
class LinearCreateIssueBlock(Block):
|
||||
@@ -135,9 +135,20 @@ class LinearSearchIssuesBlock(Block):
|
||||
description="Linear credentials with read permissions",
|
||||
required_scopes={LinearScope.READ},
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="Maximum number of results to return",
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
team_name: str | None = SchemaField(
|
||||
description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
issues: list[Issue] = SchemaField(description="List of issues")
|
||||
error: str = SchemaField(description="Error message if the search failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -145,8 +156,11 @@ class LinearSearchIssuesBlock(Block):
|
||||
description="Searches for issues on Linear",
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
||||
test_input={
|
||||
"term": "Test issue",
|
||||
"max_results": 10,
|
||||
"team_name": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||
@@ -156,10 +170,14 @@ class LinearSearchIssuesBlock(Block):
|
||||
[
|
||||
Issue(
|
||||
id="abc123",
|
||||
identifier="abc123",
|
||||
identifier="TST-123",
|
||||
title="Test issue",
|
||||
description="Test description",
|
||||
priority=1,
|
||||
state=State(
|
||||
id="state1", name="In Progress", type="started"
|
||||
),
|
||||
createdAt="2026-01-15T10:00:00.000Z",
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -168,10 +186,12 @@ class LinearSearchIssuesBlock(Block):
|
||||
"search_issues": lambda *args, **kwargs: [
|
||||
Issue(
|
||||
id="abc123",
|
||||
identifier="abc123",
|
||||
identifier="TST-123",
|
||||
title="Test issue",
|
||||
description="Test description",
|
||||
priority=1,
|
||||
state=State(id="state1", name="In Progress", type="started"),
|
||||
createdAt="2026-01-15T10:00:00.000Z",
|
||||
)
|
||||
]
|
||||
},
|
||||
@@ -181,10 +201,22 @@ class LinearSearchIssuesBlock(Block):
|
||||
async def search_issues(
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
term: str,
|
||||
max_results: int = 10,
|
||||
team_name: str | None = None,
|
||||
) -> list[Issue]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
response: list[Issue] = await client.try_search_issues(term=term)
|
||||
return response
|
||||
|
||||
# Resolve team name to ID if provided
|
||||
# Raises LinearAPIException with descriptive message if team not found
|
||||
team_id: str | None = None
|
||||
if team_name:
|
||||
team_id = await client.try_get_team_by_name(team_name=team_name)
|
||||
|
||||
return await client.try_search_issues(
|
||||
term=term,
|
||||
max_results=max_results,
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -196,7 +228,10 @@ class LinearSearchIssuesBlock(Block):
|
||||
"""Execute the issue search"""
|
||||
try:
|
||||
issues = await self.search_issues(
|
||||
credentials=credentials, term=input_data.term
|
||||
credentials=credentials,
|
||||
term=input_data.term,
|
||||
max_results=input_data.max_results,
|
||||
team_name=input_data.team_name,
|
||||
)
|
||||
yield "issues", issues
|
||||
except LinearAPIException as e:
|
||||
|
||||
@@ -36,12 +36,21 @@ class Project(BaseModel):
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class State(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str | None = (
|
||||
None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
|
||||
)
|
||||
|
||||
|
||||
class Issue(BaseModel):
|
||||
id: str
|
||||
identifier: str
|
||||
title: str
|
||||
description: str | None
|
||||
priority: int
|
||||
state: State | None = None
|
||||
project: Project | None = None
|
||||
createdAt: str | None = None
|
||||
comments: list[Comment] | None = None
|
||||
|
||||
@@ -32,7 +32,7 @@ from backend.data.model import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.prompt import compress_prompt, estimate_token_count
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
@@ -115,6 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# AI/ML API models
|
||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
@@ -270,6 +271,9 @@ MODEL_METADATA = {
|
||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
||||
), # claude-4-sonnet-20250514
|
||||
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-6
|
||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-5-20251101
|
||||
@@ -527,12 +531,12 @@ class LLMResponse(BaseModel):
|
||||
|
||||
def convert_openai_tool_fmt_to_anthropic(
|
||||
openai_tools: list[dict] | None = None,
|
||||
) -> Iterable[ToolParam] | anthropic.NotGiven:
|
||||
) -> Iterable[ToolParam] | anthropic.Omit:
|
||||
"""
|
||||
Convert OpenAI tool format to Anthropic tool format.
|
||||
"""
|
||||
if not openai_tools or len(openai_tools) == 0:
|
||||
return anthropic.NOT_GIVEN
|
||||
return anthropic.omit
|
||||
|
||||
anthropic_tools = []
|
||||
for tool in openai_tools:
|
||||
@@ -592,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
|
||||
def get_parallel_tool_calls_param(
|
||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||
):
|
||||
) -> bool | openai.Omit:
|
||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||
return openai.NOT_GIVEN
|
||||
return openai.omit
|
||||
return parallel_tool_calls
|
||||
|
||||
|
||||
@@ -634,11 +638,18 @@ async def llm_call(
|
||||
context_window = llm_model.context_window
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
prompt = compress_prompt(
|
||||
result = await compress_context(
|
||||
messages=prompt,
|
||||
target_tokens=llm_model.context_window // 2,
|
||||
lossy_ok=True,
|
||||
client=None, # Truncation-only, no LLM summarization
|
||||
reserve=0, # Caller handles response token budget separately
|
||||
)
|
||||
if result.error:
|
||||
logger.warning(
|
||||
f"Prompt compression did not meet target: {result.error}. "
|
||||
f"Proceeding with {result.token_count} tokens."
|
||||
)
|
||||
prompt = result.messages
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
|
||||
@@ -1,246 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.fx.Loop import Loop
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class MediaDurationBlock(Block):
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
media_in: MediaFileType = SchemaField(
|
||||
description="Media input (URL, data URI, or local path)."
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video (True) or audio (False).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
duration: float = SchemaField(
|
||||
description="Duration of the media file (in seconds)."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
||||
description="Block to get the duration of a media file.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=MediaDurationBlock.Input,
|
||||
output_schema=MediaDurationBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the input media locally
|
||||
local_media_path = await store_media_file(
|
||||
file=input_data.media_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
assert execution_context.graph_exec_id is not None
|
||||
media_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_media_path
|
||||
)
|
||||
|
||||
# 2) Load the clip
|
||||
if input_data.is_video:
|
||||
clip = VideoFileClip(media_abspath)
|
||||
else:
|
||||
clip = AudioFileClip(media_abspath)
|
||||
|
||||
yield "duration", clip.duration
|
||||
|
||||
|
||||
class LoopVideoBlock(Block):
|
||||
"""
|
||||
Block for looping (repeating) a video clip until a given duration or number of loops.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="The input video (can be a URL, data URI, or local path)."
|
||||
)
|
||||
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
|
||||
duration: Optional[float] = SchemaField(
|
||||
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
|
||||
default=None,
|
||||
ge=0.0,
|
||||
)
|
||||
n_loops: Optional[int] = SchemaField(
|
||||
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
|
||||
default=None,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: str = SchemaField(
|
||||
description="Looped video returned either as a relative path or a data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
||||
description="Block to loop a video to a given duration or number of repeats.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=LoopVideoBlock.Input,
|
||||
output_schema=LoopVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the input video locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||
|
||||
# 2) Load the clip
|
||||
clip = VideoFileClip(input_abspath)
|
||||
|
||||
# 3) Apply the loop effect
|
||||
looped_clip = clip
|
||||
if input_data.duration:
|
||||
# Loop until we reach the specified duration
|
||||
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
|
||||
elif input_data.n_loops:
|
||||
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
|
||||
else:
|
||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||
|
||||
assert isinstance(looped_clip, VideoFileClip)
|
||||
|
||||
# 4) Save the looped output
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||
|
||||
looped_clip = looped_clip.with_audio(clip.audio)
|
||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
|
||||
|
||||
class AddAudioToVideoBlock(Block):
|
||||
"""
|
||||
Block that adds (attaches) an audio track to an existing video.
|
||||
Optionally scale the volume of the new track.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Video input (URL, data URI, or local path)."
|
||||
)
|
||||
audio_in: MediaFileType = SchemaField(
|
||||
description="Audio input (URL, data URI, or local path)."
|
||||
)
|
||||
volume: float = SchemaField(
|
||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Final video (with attached audio), as a path or data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||
description="Block to attach an audio file to a video file using moviepy.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=AddAudioToVideoBlock.Input,
|
||||
output_schema=AddAudioToVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the inputs locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
local_audio_path = await store_media_file(
|
||||
file=input_data.audio_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||
video_abspath = os.path.join(abs_temp_dir, local_video_path)
|
||||
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
|
||||
|
||||
# 2) Load video + audio with moviepy
|
||||
video_clip = VideoFileClip(video_abspath)
|
||||
audio_clip = AudioFileClip(audio_abspath)
|
||||
# Optionally scale volume
|
||||
if input_data.volume != 1.0:
|
||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||
|
||||
# 3) Attach the new audio track
|
||||
final_clip = video_clip.with_audio(audio_clip)
|
||||
|
||||
# 4) Write to output file
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
@@ -182,10 +182,7 @@ class StagehandObserveBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
@@ -282,10 +279,7 @@ class StagehandActBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
@@ -370,10 +364,7 @@ class StagehandExtractBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import pytest
|
||||
|
||||
from backend.blocks.encoder_block import TextEncoderBlock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_encoder_basic():
|
||||
"""Test basic encoding of newlines and special characters."""
|
||||
block = TextEncoderBlock()
|
||||
result = []
|
||||
async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
|
||||
result.append(output)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "encoded_text"
|
||||
assert result[0][1] == "Hello\\nWorld"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_encoder_multiple_escapes():
|
||||
"""Test encoding of multiple escape sequences."""
|
||||
block = TextEncoderBlock()
|
||||
result = []
|
||||
async for output in block.run(
|
||||
TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
|
||||
):
|
||||
result.append(output)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "encoded_text"
|
||||
assert "\\n" in result[0][1]
|
||||
assert "\\t" in result[0][1]
|
||||
assert "\\r" in result[0][1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_encoder_unicode():
|
||||
"""Test that unicode characters are handled correctly."""
|
||||
block = TextEncoderBlock()
|
||||
result = []
|
||||
async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
|
||||
result.append(output)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "encoded_text"
|
||||
# Unicode characters should be escaped as \uXXXX sequences
|
||||
assert "\\n" in result[0][1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_encoder_empty_string():
|
||||
"""Test encoding of an empty string."""
|
||||
block = TextEncoderBlock()
|
||||
result = []
|
||||
async for output in block.run(TextEncoderBlock.Input(text="")):
|
||||
result.append(output)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "encoded_text"
|
||||
assert result[0][1] == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_encoder_error_handling():
|
||||
"""Test that encoding errors are handled gracefully."""
|
||||
from unittest.mock import patch
|
||||
|
||||
block = TextEncoderBlock()
|
||||
result = []
|
||||
|
||||
with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
|
||||
async for output in block.run(TextEncoderBlock.Input(text="test")):
|
||||
result.append(output)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "error"
|
||||
assert "Mocked encoding error" in result[0][1]
|
||||
37
autogpt_platform/backend/backend/blocks/video/__init__.py
Normal file
37
autogpt_platform/backend/backend/blocks/video/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Video editing blocks for AutoGPT Platform.
|
||||
|
||||
This module provides blocks for:
|
||||
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
|
||||
- Clipping/trimming video segments
|
||||
- Concatenating multiple videos
|
||||
- Adding text overlays
|
||||
- Adding AI-generated narration
|
||||
- Getting media duration
|
||||
- Looping videos
|
||||
- Adding audio to videos
|
||||
|
||||
Dependencies:
|
||||
- yt-dlp: For video downloading
|
||||
- moviepy: For video editing operations
|
||||
- elevenlabs: For AI narration (optional)
|
||||
"""
|
||||
|
||||
from backend.blocks.video.add_audio import AddAudioToVideoBlock
|
||||
from backend.blocks.video.clip import VideoClipBlock
|
||||
from backend.blocks.video.concat import VideoConcatBlock
|
||||
from backend.blocks.video.download import VideoDownloadBlock
|
||||
from backend.blocks.video.duration import MediaDurationBlock
|
||||
from backend.blocks.video.loop import LoopVideoBlock
|
||||
from backend.blocks.video.narration import VideoNarrationBlock
|
||||
from backend.blocks.video.text_overlay import VideoTextOverlayBlock
|
||||
|
||||
__all__ = [
|
||||
"AddAudioToVideoBlock",
|
||||
"LoopVideoBlock",
|
||||
"MediaDurationBlock",
|
||||
"VideoClipBlock",
|
||||
"VideoConcatBlock",
|
||||
"VideoDownloadBlock",
|
||||
"VideoNarrationBlock",
|
||||
"VideoTextOverlayBlock",
|
||||
]
|
||||
131
autogpt_platform/backend/backend/blocks/video/_utils.py
Normal file
131
autogpt_platform/backend/backend/blocks/video/_utils.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Shared utilities for video blocks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Known operation tags added by video blocks
|
||||
_VIDEO_OPS = (
|
||||
r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)"
|
||||
)
|
||||
|
||||
# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID
|
||||
_BLOCK_PREFIX_RE = re.compile(
|
||||
r"^[a-zA-Z0-9_-]*"
|
||||
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
r"[a-zA-Z0-9_-]*"
|
||||
r"_" + _VIDEO_OPS + r"_"
|
||||
)
|
||||
|
||||
# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output)
|
||||
_UUID_PREFIX_RE = re.compile(
|
||||
r"^[a-zA-Z0-9_-]*"
|
||||
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
r"[a-zA-Z0-9_-]*_"
|
||||
)
|
||||
|
||||
|
||||
def extract_source_name(input_path: str, max_length: int = 50) -> str:
|
||||
"""Extract the original source filename by stripping block-generated prefixes.
|
||||
|
||||
Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate
|
||||
when chaining video blocks, recovering the original human-readable name.
|
||||
|
||||
Safe for plain filenames (no UUID -> no stripping).
|
||||
Falls back to "video" if everything is stripped.
|
||||
"""
|
||||
stem = Path(input_path).stem
|
||||
|
||||
# Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively
|
||||
while _BLOCK_PREFIX_RE.match(stem):
|
||||
stem = _BLOCK_PREFIX_RE.sub("", stem, count=1)
|
||||
|
||||
# Pass 2: strip a lone {node_exec_id}_ prefix (e.g. from download block)
|
||||
if _UUID_PREFIX_RE.match(stem):
|
||||
stem = _UUID_PREFIX_RE.sub("", stem, count=1)
|
||||
|
||||
if not stem:
|
||||
return "video"
|
||||
|
||||
return stem[:max_length]
|
||||
|
||||
|
||||
def get_video_codecs(output_path: str) -> tuple[str, str]:
|
||||
"""Get appropriate video and audio codecs based on output file extension.
|
||||
|
||||
Args:
|
||||
output_path: Path to the output file (used to determine extension)
|
||||
|
||||
Returns:
|
||||
Tuple of (video_codec, audio_codec)
|
||||
|
||||
Codec mappings:
|
||||
- .mp4: H.264 + AAC (universal compatibility)
|
||||
- .webm: VP8 + Vorbis (web streaming)
|
||||
- .mkv: H.264 + AAC (container supports many codecs)
|
||||
- .mov: H.264 + AAC (Apple QuickTime, widely compatible)
|
||||
- .m4v: H.264 + AAC (Apple iTunes/devices)
|
||||
- .avi: MPEG-4 + MP3 (legacy Windows)
|
||||
"""
|
||||
ext = os.path.splitext(output_path)[1].lower()
|
||||
|
||||
codec_map: dict[str, tuple[str, str]] = {
|
||||
".mp4": ("libx264", "aac"),
|
||||
".webm": ("libvpx", "libvorbis"),
|
||||
".mkv": ("libx264", "aac"),
|
||||
".mov": ("libx264", "aac"),
|
||||
".m4v": ("libx264", "aac"),
|
||||
".avi": ("mpeg4", "libmp3lame"),
|
||||
}
|
||||
|
||||
return codec_map.get(ext, ("libx264", "aac"))
|
||||
|
||||
|
||||
def strip_chapters_inplace(video_path: str) -> None:
|
||||
"""Strip chapter metadata from a media file in-place using ffmpeg.
|
||||
|
||||
MoviePy 2.x crashes with IndexError when parsing files with embedded
|
||||
chapter metadata (https://github.com/Zulko/moviepy/issues/2419).
|
||||
This strips chapters without re-encoding.
|
||||
|
||||
Args:
|
||||
video_path: Absolute path to the media file to strip chapters from.
|
||||
"""
|
||||
base, ext = os.path.splitext(video_path)
|
||||
tmp_path = base + ".tmp" + ext
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
video_path,
|
||||
"-map_chapters",
|
||||
"-1",
|
||||
"-codec",
|
||||
"copy",
|
||||
tmp_path,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning(
|
||||
"ffmpeg chapter strip failed (rc=%d): %s",
|
||||
result.returncode,
|
||||
result.stderr,
|
||||
)
|
||||
return
|
||||
os.replace(tmp_path, video_path)
|
||||
except FileNotFoundError:
|
||||
logger.warning("ffmpeg not found; skipping chapter strip")
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
113
autogpt_platform/backend/backend/blocks/video/add_audio.py
Normal file
113
autogpt_platform/backend/backend/blocks/video/add_audio.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""AddAudioToVideoBlock - Attach an audio track to a video file."""
|
||||
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class AddAudioToVideoBlock(Block):
|
||||
"""Add (attach) an audio track to an existing video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Video input (URL, data URI, or local path)."
|
||||
)
|
||||
audio_in: MediaFileType = SchemaField(
|
||||
description="Audio input (URL, data URI, or local path)."
|
||||
)
|
||||
volume: float = SchemaField(
|
||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Final video (with attached audio), as a path or data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||
description="Block to attach an audio file to a video file using moviepy.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=AddAudioToVideoBlock.Input,
|
||||
output_schema=AddAudioToVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the inputs locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
local_audio_path = await store_media_file(
|
||||
file=input_data.audio_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
video_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||
audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path)
|
||||
|
||||
# 2) Load video + audio with moviepy
|
||||
strip_chapters_inplace(video_abspath)
|
||||
strip_chapters_inplace(audio_abspath)
|
||||
video_clip = None
|
||||
audio_clip = None
|
||||
final_clip = None
|
||||
try:
|
||||
video_clip = VideoFileClip(video_abspath)
|
||||
audio_clip = AudioFileClip(audio_abspath)
|
||||
# Optionally scale volume
|
||||
if input_data.volume != 1.0:
|
||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||
|
||||
# 3) Attach the new audio track
|
||||
final_clip = video_clip.with_audio(audio_clip)
|
||||
|
||||
# 4) Write to output file
|
||||
source = extract_source_name(local_video_path)
|
||||
output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4")
|
||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||
final_clip.write_videofile(
|
||||
output_abspath, codec="libx264", audio_codec="aac"
|
||||
)
|
||||
finally:
|
||||
if final_clip:
|
||||
final_clip.close()
|
||||
if audio_clip:
|
||||
audio_clip.close()
|
||||
if video_clip:
|
||||
video_clip.close()
|
||||
|
||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
167
autogpt_platform/backend/backend/blocks/video/clip.py
Normal file
167
autogpt_platform/backend/backend/blocks/video/clip.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""VideoClipBlock - Extract a segment from a video file."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import (
|
||||
extract_source_name,
|
||||
get_video_codecs,
|
||||
strip_chapters_inplace,
|
||||
)
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoClipBlock(Block):
|
||||
"""Extract a time segment from a video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Input video (URL, data URI, or local path)"
|
||||
)
|
||||
start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
|
||||
end_time: float = SchemaField(description="End time in seconds", ge=0.0)
|
||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
||||
description="Output format", default="mp4", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Clipped video file (path or data URI)"
|
||||
)
|
||||
duration: float = SchemaField(description="Clip duration in seconds")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8f539119-e580-4d86-ad41-86fbcb22abb1",
|
||||
description="Extract a time segment from a video",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"video_in": "/tmp/test.mp4",
|
||||
"start_time": 0.0,
|
||||
"end_time": 10.0,
|
||||
},
|
||||
test_output=[("video_out", str), ("duration", float)],
|
||||
test_mock={
|
||||
"_clip_video": lambda *args: 10.0,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _clip_video(
|
||||
self,
|
||||
video_abspath: str,
|
||||
output_abspath: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
) -> float:
|
||||
"""Extract a clip from a video. Extracted for testability."""
|
||||
clip = None
|
||||
subclip = None
|
||||
try:
|
||||
strip_chapters_inplace(video_abspath)
|
||||
clip = VideoFileClip(video_abspath)
|
||||
subclip = clip.subclipped(start_time, end_time)
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
subclip.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
return subclip.duration
|
||||
finally:
|
||||
if subclip:
|
||||
subclip.close()
|
||||
if clip:
|
||||
clip.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Validate time range
|
||||
if input_data.end_time <= input_data.start_time:
|
||||
raise BlockExecutionError(
|
||||
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
)
|
||||
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store the input video locally
|
||||
local_video_path = await self._store_input_video(
|
||||
execution_context, input_data.video_in
|
||||
)
|
||||
video_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_video_path
|
||||
)
|
||||
|
||||
# Build output path
|
||||
source = extract_source_name(local_video_path)
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_clip_{source}.{input_data.output_format}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
duration = self._clip_video(
|
||||
video_abspath,
|
||||
output_abspath,
|
||||
input_data.start_time,
|
||||
input_data.end_time,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
yield "duration", duration
|
||||
|
||||
except BlockExecutionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to clip video: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
227
autogpt_platform/backend/backend/blocks/video/concat.py
Normal file
227
autogpt_platform/backend/backend/blocks/video/concat.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""VideoConcatBlock - Concatenate multiple video clips into one."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from moviepy import concatenate_videoclips
|
||||
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import (
|
||||
extract_source_name,
|
||||
get_video_codecs,
|
||||
strip_chapters_inplace,
|
||||
)
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoConcatBlock(Block):
|
||||
"""Merge multiple video clips into one continuous video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
videos: list[MediaFileType] = SchemaField(
|
||||
description="List of video files to concatenate (in order)"
|
||||
)
|
||||
transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
|
||||
description="Transition between clips", default="none"
|
||||
)
|
||||
transition_duration: int = SchemaField(
|
||||
description="Transition duration in seconds",
|
||||
default=1,
|
||||
ge=0,
|
||||
advanced=True,
|
||||
)
|
||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
||||
description="Output format", default="mp4", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Concatenated video file (path or data URI)"
|
||||
)
|
||||
total_duration: float = SchemaField(description="Total duration in seconds")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
|
||||
description="Merge multiple video clips into one continuous video",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"videos": ["/tmp/a.mp4", "/tmp/b.mp4"],
|
||||
},
|
||||
test_output=[
|
||||
("video_out", str),
|
||||
("total_duration", float),
|
||||
],
|
||||
test_mock={
|
||||
"_concat_videos": lambda *args: 20.0,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _concat_videos(
|
||||
self,
|
||||
video_abspaths: list[str],
|
||||
output_abspath: str,
|
||||
transition: str,
|
||||
transition_duration: int,
|
||||
) -> float:
|
||||
"""Concatenate videos. Extracted for testability.
|
||||
|
||||
Returns:
|
||||
Total duration of the concatenated video.
|
||||
"""
|
||||
clips = []
|
||||
faded_clips = []
|
||||
final = None
|
||||
try:
|
||||
# Load clips
|
||||
for v in video_abspaths:
|
||||
strip_chapters_inplace(v)
|
||||
clips.append(VideoFileClip(v))
|
||||
|
||||
# Validate transition_duration against shortest clip
|
||||
if transition in {"crossfade", "fade_black"} and transition_duration > 0:
|
||||
min_duration = min(c.duration for c in clips)
|
||||
if transition_duration >= min_duration:
|
||||
raise BlockExecutionError(
|
||||
message=(
|
||||
f"transition_duration ({transition_duration}s) must be "
|
||||
f"shorter than the shortest clip ({min_duration:.2f}s)"
|
||||
),
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
)
|
||||
|
||||
if transition == "crossfade":
|
||||
for i, clip in enumerate(clips):
|
||||
effects = []
|
||||
if i > 0:
|
||||
effects.append(CrossFadeIn(transition_duration))
|
||||
if i < len(clips) - 1:
|
||||
effects.append(CrossFadeOut(transition_duration))
|
||||
if effects:
|
||||
clip = clip.with_effects(effects)
|
||||
faded_clips.append(clip)
|
||||
final = concatenate_videoclips(
|
||||
faded_clips,
|
||||
method="compose",
|
||||
padding=-transition_duration,
|
||||
)
|
||||
elif transition == "fade_black":
|
||||
for clip in clips:
|
||||
faded = clip.with_effects(
|
||||
[FadeIn(transition_duration), FadeOut(transition_duration)]
|
||||
)
|
||||
faded_clips.append(faded)
|
||||
final = concatenate_videoclips(faded_clips)
|
||||
else:
|
||||
final = concatenate_videoclips(clips)
|
||||
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
final.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
|
||||
return final.duration
|
||||
finally:
|
||||
if final:
|
||||
final.close()
|
||||
for clip in faded_clips:
|
||||
clip.close()
|
||||
for clip in clips:
|
||||
clip.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Validate minimum clips
|
||||
if len(input_data.videos) < 2:
|
||||
raise BlockExecutionError(
|
||||
message="At least 2 videos are required for concatenation",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
)
|
||||
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store all input videos locally
|
||||
video_abspaths = []
|
||||
for video in input_data.videos:
|
||||
local_path = await self._store_input_video(execution_context, video)
|
||||
video_abspaths.append(
|
||||
get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||
)
|
||||
|
||||
# Build output path
|
||||
source = (
|
||||
extract_source_name(video_abspaths[0]) if video_abspaths else "video"
|
||||
)
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_concat_{source}.{input_data.output_format}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
total_duration = self._concat_videos(
|
||||
video_abspaths,
|
||||
output_abspath,
|
||||
input_data.transition,
|
||||
input_data.transition_duration,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
yield "total_duration", total_duration
|
||||
|
||||
except BlockExecutionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to concatenate videos: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
172
autogpt_platform/backend/backend/blocks/video/download.py
Normal file
172
autogpt_platform/backend/backend/blocks/video/download.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
|
||||
|
||||
import os
|
||||
import typing
|
||||
from typing import Literal
|
||||
|
||||
import yt_dlp
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from yt_dlp import _Params
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoDownloadBlock(Block):
|
||||
"""Download video from URL using yt-dlp."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
url: str = SchemaField(
|
||||
description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
|
||||
placeholder="https://www.youtube.com/watch?v=...",
|
||||
)
|
||||
quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
|
||||
description="Video quality preference", default="720p"
|
||||
)
|
||||
output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
|
||||
description="Output video format", default="mp4", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_file: MediaFileType = SchemaField(
|
||||
description="Downloaded video (path or data URI)"
|
||||
)
|
||||
duration: float = SchemaField(description="Video duration in seconds")
|
||||
title: str = SchemaField(description="Video title from source")
|
||||
source_url: str = SchemaField(description="Original source URL")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
|
||||
description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
disabled=True, # Disable until we can sandbox yt-dlp and handle security implications
|
||||
test_input={
|
||||
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
||||
"quality": "480p",
|
||||
},
|
||||
test_output=[
|
||||
("video_file", str),
|
||||
("duration", float),
|
||||
("title", str),
|
||||
("source_url", str),
|
||||
],
|
||||
test_mock={
|
||||
"_download_video": lambda *args: (
|
||||
"video.mp4",
|
||||
212.0,
|
||||
"Test Video",
|
||||
),
|
||||
"_store_output_video": lambda *args, **kwargs: "video.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _get_format_string(self, quality: str) -> str:
|
||||
formats = {
|
||||
"best": "bestvideo+bestaudio/best",
|
||||
"1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
|
||||
"720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
|
||||
"480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
|
||||
"audio_only": "bestaudio/best",
|
||||
}
|
||||
return formats.get(quality, formats["720p"])
|
||||
|
||||
def _download_video(
|
||||
self,
|
||||
url: str,
|
||||
quality: str,
|
||||
output_format: str,
|
||||
output_dir: str,
|
||||
node_exec_id: str,
|
||||
) -> tuple[str, float, str]:
|
||||
"""Download video. Extracted for testability."""
|
||||
output_template = os.path.join(
|
||||
output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
|
||||
)
|
||||
|
||||
ydl_opts: "_Params" = {
|
||||
"format": f"{self._get_format_string(quality)}/best",
|
||||
"outtmpl": output_template,
|
||||
"merge_output_format": output_format,
|
||||
"quiet": True,
|
||||
"no_warnings": True,
|
||||
}
|
||||
|
||||
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||
info = ydl.extract_info(url, download=True)
|
||||
video_path = ydl.prepare_filename(info)
|
||||
|
||||
# Handle format conversion in filename
|
||||
if not video_path.endswith(f".{output_format}"):
|
||||
video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"
|
||||
|
||||
# Return just the filename, not the full path
|
||||
filename = os.path.basename(video_path)
|
||||
|
||||
return (
|
||||
filename,
|
||||
info.get("duration") or 0.0,
|
||||
info.get("title") or "Unknown",
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Get the exec file directory
|
||||
output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
filename, duration, title = self._download_video(
|
||||
input_data.url,
|
||||
input_data.quality,
|
||||
input_data.output_format,
|
||||
output_dir,
|
||||
node_exec_id,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, MediaFileType(filename)
|
||||
)
|
||||
|
||||
yield "video_file", video_out
|
||||
yield "duration", duration
|
||||
yield "title", title
|
||||
yield "source_url", input_data.url
|
||||
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to download video: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
77
autogpt_platform/backend/backend/blocks/video/duration.py
Normal file
77
autogpt_platform/backend/backend/blocks/video/duration.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""MediaDurationBlock - Get the duration of a media file."""
|
||||
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import strip_chapters_inplace
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class MediaDurationBlock(Block):
|
||||
"""Get the duration of a media file (video or audio)."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
media_in: MediaFileType = SchemaField(
|
||||
description="Media input (URL, data URI, or local path)."
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video (True) or audio (False).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
duration: float = SchemaField(
|
||||
description="Duration of the media file (in seconds)."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
||||
description="Block to get the duration of a media file.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=MediaDurationBlock.Input,
|
||||
output_schema=MediaDurationBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the input media locally
|
||||
local_media_path = await store_media_file(
|
||||
file=input_data.media_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
assert execution_context.graph_exec_id is not None
|
||||
media_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_media_path
|
||||
)
|
||||
|
||||
# 2) Strip chapters to avoid MoviePy crash, then load the clip
|
||||
strip_chapters_inplace(media_abspath)
|
||||
clip = None
|
||||
try:
|
||||
if input_data.is_video:
|
||||
clip = VideoFileClip(media_abspath)
|
||||
else:
|
||||
clip = AudioFileClip(media_abspath)
|
||||
|
||||
duration = clip.duration
|
||||
finally:
|
||||
if clip:
|
||||
clip.close()
|
||||
|
||||
yield "duration", duration
|
||||
115
autogpt_platform/backend/backend/blocks/video/loop.py
Normal file
115
autogpt_platform/backend/backend/blocks/video/loop.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from moviepy.video.fx.Loop import Loop
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class LoopVideoBlock(Block):
|
||||
"""Loop (repeat) a video clip until a given duration or number of loops."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="The input video (can be a URL, data URI, or local path)."
|
||||
)
|
||||
duration: Optional[float] = SchemaField(
|
||||
description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.",
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=3600.0, # Max 1 hour to prevent disk exhaustion
|
||||
)
|
||||
n_loops: Optional[int] = SchemaField(
|
||||
description="Number of times to repeat the video. Either n_loops or duration must be provided.",
|
||||
default=None,
|
||||
ge=1,
|
||||
le=10, # Max 10 loops to prevent disk exhaustion
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Looped video returned either as a relative path or a data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
||||
description="Block to loop a video to a given duration or number of repeats.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=LoopVideoBlock.Input,
|
||||
output_schema=LoopVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the input video locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||
|
||||
# 2) Load the clip
|
||||
strip_chapters_inplace(input_abspath)
|
||||
clip = None
|
||||
looped_clip = None
|
||||
try:
|
||||
clip = VideoFileClip(input_abspath)
|
||||
|
||||
# 3) Apply the loop effect
|
||||
if input_data.duration:
|
||||
# Loop until we reach the specified duration
|
||||
looped_clip = clip.with_effects([Loop(duration=input_data.duration)])
|
||||
elif input_data.n_loops:
|
||||
looped_clip = clip.with_effects([Loop(n=input_data.n_loops)])
|
||||
else:
|
||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||
|
||||
assert isinstance(looped_clip, VideoFileClip)
|
||||
|
||||
# 4) Save the looped output
|
||||
source = extract_source_name(local_video_path)
|
||||
output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4")
|
||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||
|
||||
looped_clip = looped_clip.with_audio(clip.audio)
|
||||
looped_clip.write_videofile(
|
||||
output_abspath, codec="libx264", audio_codec="aac"
|
||||
)
|
||||
finally:
|
||||
if looped_clip:
|
||||
looped_clip.close()
|
||||
if clip:
|
||||
clip.close()
|
||||
|
||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
267
autogpt_platform/backend/backend/blocks/video/narration.py
Normal file
267
autogpt_platform/backend/backend/blocks/video/narration.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""VideoNarrationBlock - Generate AI voice narration and add to video."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from elevenlabs import ElevenLabs
|
||||
from moviepy import CompositeAudioClip
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.elevenlabs._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ElevenLabsCredentials,
|
||||
ElevenLabsCredentialsInput,
|
||||
)
|
||||
from backend.blocks.video._utils import (
|
||||
extract_source_name,
|
||||
get_video_codecs,
|
||||
strip_chapters_inplace,
|
||||
)
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoNarrationBlock(Block):
|
||||
"""Generate AI narration and add to video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: ElevenLabsCredentialsInput = CredentialsField(
|
||||
description="ElevenLabs API key for voice synthesis"
|
||||
)
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Input video (URL, data URI, or local path)"
|
||||
)
|
||||
script: str = SchemaField(description="Narration script text")
|
||||
voice_id: str = SchemaField(
|
||||
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
|
||||
)
|
||||
model_id: Literal[
|
||||
"eleven_multilingual_v2",
|
||||
"eleven_flash_v2_5",
|
||||
"eleven_turbo_v2_5",
|
||||
"eleven_turbo_v2",
|
||||
] = SchemaField(
|
||||
description="ElevenLabs TTS model",
|
||||
default="eleven_multilingual_v2",
|
||||
)
|
||||
mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
|
||||
description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
|
||||
default="ducking",
|
||||
)
|
||||
narration_volume: float = SchemaField(
|
||||
description="Narration volume (0.0 to 2.0)",
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
advanced=True,
|
||||
)
|
||||
original_volume: float = SchemaField(
|
||||
description="Original audio volume when mixing (0.0 to 1.0)",
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Video with narration (path or data URI)"
|
||||
)
|
||||
audio_file: MediaFileType = SchemaField(
|
||||
description="Generated audio file (path or data URI)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3d036b53-859c-4b17-9826-ca340f736e0e",
|
||||
description="Generate AI narration and add to video",
|
||||
categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"video_in": "/tmp/test.mp4",
|
||||
"script": "Hello world",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("video_out", str), ("audio_file", str)],
|
||||
test_mock={
|
||||
"_generate_narration_audio": lambda *args: b"mock audio content",
|
||||
"_add_narration_to_video": lambda *args: None,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _generate_narration_audio(
|
||||
self, api_key: str, script: str, voice_id: str, model_id: str
|
||||
) -> bytes:
|
||||
"""Generate narration audio via ElevenLabs API."""
|
||||
client = ElevenLabs(api_key=api_key)
|
||||
audio_generator = client.text_to_speech.convert(
|
||||
voice_id=voice_id,
|
||||
text=script,
|
||||
model_id=model_id,
|
||||
)
|
||||
# The SDK returns a generator, collect all chunks
|
||||
return b"".join(audio_generator)
|
||||
|
||||
def _add_narration_to_video(
|
||||
self,
|
||||
video_abspath: str,
|
||||
audio_abspath: str,
|
||||
output_abspath: str,
|
||||
mix_mode: str,
|
||||
narration_volume: float,
|
||||
original_volume: float,
|
||||
) -> None:
|
||||
"""Add narration audio to video. Extracted for testability."""
|
||||
video = None
|
||||
final = None
|
||||
narration_original = None
|
||||
narration_scaled = None
|
||||
original = None
|
||||
|
||||
try:
|
||||
strip_chapters_inplace(video_abspath)
|
||||
video = VideoFileClip(video_abspath)
|
||||
narration_original = AudioFileClip(audio_abspath)
|
||||
narration_scaled = narration_original.with_volume_scaled(narration_volume)
|
||||
narration = narration_scaled
|
||||
|
||||
if mix_mode == "replace":
|
||||
final_audio = narration
|
||||
elif mix_mode == "mix":
|
||||
if video.audio:
|
||||
original = video.audio.with_volume_scaled(original_volume)
|
||||
final_audio = CompositeAudioClip([original, narration])
|
||||
else:
|
||||
final_audio = narration
|
||||
else: # ducking - apply stronger attenuation
|
||||
if video.audio:
|
||||
# Ducking uses a much lower volume for original audio
|
||||
ducking_volume = original_volume * 0.3
|
||||
original = video.audio.with_volume_scaled(ducking_volume)
|
||||
final_audio = CompositeAudioClip([original, narration])
|
||||
else:
|
||||
final_audio = narration
|
||||
|
||||
final = video.with_audio(final_audio)
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
final.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
|
||||
finally:
|
||||
if original:
|
||||
original.close()
|
||||
if narration_scaled:
|
||||
narration_scaled.close()
|
||||
if narration_original:
|
||||
narration_original.close()
|
||||
if final:
|
||||
final.close()
|
||||
if video:
|
||||
video.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: ElevenLabsCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store the input video locally
|
||||
local_video_path = await self._store_input_video(
|
||||
execution_context, input_data.video_in
|
||||
)
|
||||
video_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_video_path
|
||||
)
|
||||
|
||||
# Generate narration audio via ElevenLabs
|
||||
audio_content = self._generate_narration_audio(
|
||||
credentials.api_key.get_secret_value(),
|
||||
input_data.script,
|
||||
input_data.voice_id,
|
||||
input_data.model_id,
|
||||
)
|
||||
|
||||
# Save audio to exec file path
|
||||
audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
|
||||
audio_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, audio_filename
|
||||
)
|
||||
os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
|
||||
with open(audio_abspath, "wb") as f:
|
||||
f.write(audio_content)
|
||||
|
||||
# Add narration to video
|
||||
source = extract_source_name(local_video_path)
|
||||
output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4")
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
self._add_narration_to_video(
|
||||
video_abspath,
|
||||
audio_abspath,
|
||||
output_abspath,
|
||||
input_data.mix_mode,
|
||||
input_data.narration_volume,
|
||||
input_data.original_volume,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
audio_out = await self._store_output_video(
|
||||
execution_context, audio_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
yield "audio_file", audio_out
|
||||
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to add narration: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
231
autogpt_platform/backend/backend/blocks/video/text_overlay.py
Normal file
231
autogpt_platform/backend/backend/blocks/video/text_overlay.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""VideoTextOverlayBlock - Add text overlay to video."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from moviepy import CompositeVideoClip, TextClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import (
|
||||
extract_source_name,
|
||||
get_video_codecs,
|
||||
strip_chapters_inplace,
|
||||
)
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoTextOverlayBlock(Block):
|
||||
"""Add text overlay/caption to video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Input video (URL, data URI, or local path)"
|
||||
)
|
||||
text: str = SchemaField(description="Text to overlay on video")
|
||||
position: Literal[
|
||||
"top",
|
||||
"center",
|
||||
"bottom",
|
||||
"top-left",
|
||||
"top-right",
|
||||
"bottom-left",
|
||||
"bottom-right",
|
||||
] = SchemaField(description="Position of text on screen", default="bottom")
|
||||
start_time: float | None = SchemaField(
|
||||
description="When to show text (seconds). None = entire video",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
end_time: float | None = SchemaField(
|
||||
description="When to hide text (seconds). None = until end",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
font_size: int = SchemaField(
|
||||
description="Font size", default=48, ge=12, le=200, advanced=True
|
||||
)
|
||||
font_color: str = SchemaField(
|
||||
description="Font color (hex or name)", default="white", advanced=True
|
||||
)
|
||||
bg_color: str | None = SchemaField(
|
||||
description="Background color behind text (None for transparent)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Video with text overlay (path or data URI)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8ef14de6-cc90-430a-8cfa-3a003be92454",
|
||||
description="Add text overlay/caption to video",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
disabled=True, # Disable until we can lockdown imagemagick security policy
|
||||
test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
|
||||
test_output=[("video_out", str)],
|
||||
test_mock={
|
||||
"_add_text_overlay": lambda *args: None,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _add_text_overlay(
|
||||
self,
|
||||
video_abspath: str,
|
||||
output_abspath: str,
|
||||
text: str,
|
||||
position: str,
|
||||
start_time: float | None,
|
||||
end_time: float | None,
|
||||
font_size: int,
|
||||
font_color: str,
|
||||
bg_color: str | None,
|
||||
) -> None:
|
||||
"""Add text overlay to video. Extracted for testability."""
|
||||
video = None
|
||||
final = None
|
||||
txt_clip = None
|
||||
try:
|
||||
strip_chapters_inplace(video_abspath)
|
||||
video = VideoFileClip(video_abspath)
|
||||
|
||||
txt_clip = TextClip(
|
||||
text=text,
|
||||
font_size=font_size,
|
||||
color=font_color,
|
||||
bg_color=bg_color,
|
||||
)
|
||||
|
||||
# Position mapping
|
||||
pos_map = {
|
||||
"top": ("center", "top"),
|
||||
"center": ("center", "center"),
|
||||
"bottom": ("center", "bottom"),
|
||||
"top-left": ("left", "top"),
|
||||
"top-right": ("right", "top"),
|
||||
"bottom-left": ("left", "bottom"),
|
||||
"bottom-right": ("right", "bottom"),
|
||||
}
|
||||
|
||||
txt_clip = txt_clip.with_position(pos_map[position])
|
||||
|
||||
# Set timing
|
||||
start = start_time or 0
|
||||
end = end_time or video.duration
|
||||
duration = max(0, end - start)
|
||||
txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)
|
||||
|
||||
final = CompositeVideoClip([video, txt_clip])
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
final.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
|
||||
finally:
|
||||
if txt_clip:
|
||||
txt_clip.close()
|
||||
if final:
|
||||
final.close()
|
||||
if video:
|
||||
video.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Validate time range if both are provided
|
||||
if (
|
||||
input_data.start_time is not None
|
||||
and input_data.end_time is not None
|
||||
and input_data.end_time <= input_data.start_time
|
||||
):
|
||||
raise BlockExecutionError(
|
||||
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
)
|
||||
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store the input video locally
|
||||
local_video_path = await self._store_input_video(
|
||||
execution_context, input_data.video_in
|
||||
)
|
||||
video_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_video_path
|
||||
)
|
||||
|
||||
# Build output path
|
||||
source = extract_source_name(local_video_path)
|
||||
output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4")
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
self._add_text_overlay(
|
||||
video_abspath,
|
||||
output_abspath,
|
||||
input_data.text,
|
||||
input_data.position,
|
||||
input_data.start_time,
|
||||
input_data.end_time,
|
||||
input_data.font_size,
|
||||
input_data.font_color,
|
||||
input_data.bg_color,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
|
||||
except BlockExecutionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to add text overlay: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
@@ -165,10 +165,13 @@ class TranscribeYoutubeVideoBlock(Block):
|
||||
credentials: WebshareProxyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
video_id = self.extract_video_id(input_data.youtube_url)
|
||||
yield "video_id", video_id
|
||||
try:
|
||||
video_id = self.extract_video_id(input_data.youtube_url)
|
||||
transcript = self.get_transcript(video_id, credentials)
|
||||
transcript_text = self.format_transcript(transcript=transcript)
|
||||
|
||||
transcript = self.get_transcript(video_id, credentials)
|
||||
transcript_text = self.format_transcript(transcript=transcript)
|
||||
|
||||
yield "transcript", transcript_text
|
||||
# Only yield after all operations succeed
|
||||
yield "video_id", video_id
|
||||
yield "transcript", transcript_text
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -246,7 +246,9 @@ class BlockSchema(BaseModel):
|
||||
f"is not of type {CredentialsMetaInput.__name__}"
|
||||
)
|
||||
|
||||
credentials_fields[field_name].validate_credentials_field_schema(cls)
|
||||
CredentialsMetaInput.validate_credentials_field_schema(
|
||||
cls.get_field_schema(field_name), field_name
|
||||
)
|
||||
|
||||
elif field_name in credentials_fields:
|
||||
raise KeyError(
|
||||
@@ -873,14 +875,13 @@ def is_block_auth_configured(
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
# First, sync all provider costs to blocks
|
||||
# Imported here to avoid circular import
|
||||
from backend.sdk.cost_integration import sync_all_provider_costs
|
||||
from backend.util.retry import func_retry
|
||||
|
||||
sync_all_provider_costs()
|
||||
|
||||
for cls in get_blocks().values():
|
||||
block = cls()
|
||||
@func_retry
|
||||
async def sync_block_to_db(block: Block) -> None:
|
||||
existing_block = await AgentBlock.prisma().find_first(
|
||||
where={"OR": [{"id": block.id}, {"name": block.name}]}
|
||||
)
|
||||
@@ -893,7 +894,7 @@ async def initialize_blocks() -> None:
|
||||
outputSchema=json.dumps(block.output_schema.jsonschema()),
|
||||
)
|
||||
)
|
||||
continue
|
||||
return
|
||||
|
||||
input_schema = json.dumps(block.input_schema.jsonschema())
|
||||
output_schema = json.dumps(block.output_schema.jsonschema())
|
||||
@@ -913,6 +914,25 @@ async def initialize_blocks() -> None:
|
||||
},
|
||||
)
|
||||
|
||||
failed_blocks: list[str] = []
|
||||
for cls in get_blocks().values():
|
||||
block = cls()
|
||||
try:
|
||||
await sync_block_to_db(block)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to sync block {block.name} to database: {e}. "
|
||||
"Block is still available in memory.",
|
||||
exc_info=True,
|
||||
)
|
||||
failed_blocks.append(block.name)
|
||||
|
||||
if failed_blocks:
|
||||
logger.error(
|
||||
f"Failed to sync {len(failed_blocks)} block(s) to database: "
|
||||
f"{', '.join(failed_blocks)}. These blocks are still available in memory."
|
||||
)
|
||||
|
||||
|
||||
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
||||
def get_block(block_id: str) -> AnyBlockSchema | None:
|
||||
|
||||
@@ -36,12 +36,14 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.blocks.video.narration import VideoNarrationBlock
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
anthropic_credentials,
|
||||
apollo_credentials,
|
||||
did_credentials,
|
||||
elevenlabs_credentials,
|
||||
enrichlayer_credentials,
|
||||
groq_credentials,
|
||||
ideogram_credentials,
|
||||
@@ -78,6 +80,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_4_6_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
@@ -639,4 +642,16 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
),
|
||||
],
|
||||
VideoNarrationBlock: [
|
||||
BlockCost(
|
||||
cost_amount=5, # ElevenLabs TTS cost
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": elevenlabs_credentials.id,
|
||||
"provider": elevenlabs_credentials.provider,
|
||||
"type": elevenlabs_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
@@ -134,6 +134,16 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
||||
user_credit.time_now = lambda: month1
|
||||
|
||||
# IMPORTANT: Set updatedAt to December of previous year to ensure it's
|
||||
# in a different month than month1 (January). This fixes a timing bug
|
||||
# where if the test runs in early February, 35 days ago would be January,
|
||||
# matching the mocked month1 and preventing the refill from triggering.
|
||||
dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
|
||||
await UserBalance.prisma().update(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={"updatedAt": dec_previous_year},
|
||||
)
|
||||
|
||||
# First call in month 1 should trigger refill
|
||||
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
assert balance == REFILL_VALUE # Should get 1000 credits
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from multiprocessing import Manager
|
||||
from queue import Empty
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
@@ -1200,12 +1199,16 @@ class NodeExecutionEntry(BaseModel):
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
"""
|
||||
Queue for managing the execution of agents.
|
||||
This will be shared between different processes
|
||||
Thread-safe queue for managing node execution within a single graph execution.
|
||||
|
||||
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
|
||||
threads within the same process. If migrating back to ProcessPoolExecutor,
|
||||
replace with multiprocessing.Manager().Queue() for cross-process safety.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.queue = Manager().Queue()
|
||||
# Thread-safe queue (not multiprocessing) — see class docstring
|
||||
self.queue: queue.Queue[T] = queue.Queue()
|
||||
|
||||
def add(self, execution: T) -> T:
|
||||
self.queue.put(execution)
|
||||
@@ -1220,7 +1223,7 @@ class ExecutionQueue(Generic[T]):
|
||||
def get_or_none(self) -> T | None:
|
||||
try:
|
||||
return self.queue.get_nowait()
|
||||
except Empty:
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Tests for ExecutionQueue thread-safety."""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
from backend.data.execution import ExecutionQueue
|
||||
|
||||
|
||||
def test_execution_queue_uses_stdlib_queue():
|
||||
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
|
||||
q = ExecutionQueue()
|
||||
assert isinstance(q.queue, queue.Queue)
|
||||
|
||||
|
||||
def test_basic_operations():
|
||||
"""Test add, get, empty, and get_or_none."""
|
||||
q = ExecutionQueue()
|
||||
|
||||
assert q.empty() is True
|
||||
assert q.get_or_none() is None
|
||||
|
||||
result = q.add("item1")
|
||||
assert result == "item1"
|
||||
assert q.empty() is False
|
||||
|
||||
item = q.get()
|
||||
assert item == "item1"
|
||||
assert q.empty() is True
|
||||
|
||||
|
||||
def test_thread_safety():
|
||||
"""Test concurrent access from multiple threads."""
|
||||
q = ExecutionQueue()
|
||||
results = []
|
||||
num_items = 100
|
||||
|
||||
def producer():
|
||||
for i in range(num_items):
|
||||
q.add(f"item_{i}")
|
||||
|
||||
def consumer():
|
||||
count = 0
|
||||
while count < num_items:
|
||||
item = q.get_or_none()
|
||||
if item is not None:
|
||||
results.append(item)
|
||||
count += 1
|
||||
|
||||
producer_thread = threading.Thread(target=producer)
|
||||
consumer_thread = threading.Thread(target=consumer)
|
||||
|
||||
producer_thread.start()
|
||||
consumer_thread.start()
|
||||
|
||||
producer_thread.join(timeout=5)
|
||||
consumer_thread.join(timeout=5)
|
||||
|
||||
assert len(results) == num_items
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Self, cast
|
||||
|
||||
from prisma.enums import SubmissionStatus
|
||||
from prisma.models import (
|
||||
@@ -20,7 +20,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel, BeforeValidator, Field, create_model
|
||||
from pydantic import BaseModel, BeforeValidator, Field
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -30,7 +30,6 @@ from backend.data.db import prisma as db
|
||||
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
@@ -45,7 +44,6 @@ from .block import (
|
||||
AnyBlockSchema,
|
||||
Block,
|
||||
BlockInput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
EmptySchema,
|
||||
get_block,
|
||||
@@ -113,10 +111,12 @@ class Link(BaseDbModel):
|
||||
|
||||
class Node(BaseDbModel):
|
||||
block_id: str
|
||||
input_default: BlockInput = {} # dict[input_name, default_value]
|
||||
metadata: dict[str, Any] = {}
|
||||
input_links: list[Link] = []
|
||||
output_links: list[Link] = []
|
||||
input_default: BlockInput = Field( # dict[input_name, default_value]
|
||||
default_factory=dict
|
||||
)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
input_links: list[Link] = Field(default_factory=list)
|
||||
output_links: list[Link] = Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def credentials_optional(self) -> bool:
|
||||
@@ -221,18 +221,33 @@ class NodeModel(Node):
|
||||
return result
|
||||
|
||||
|
||||
class BaseGraph(BaseDbModel):
|
||||
class GraphBaseMeta(BaseDbModel):
|
||||
"""
|
||||
Shared base for `GraphMeta` and `BaseGraph`, with core graph metadata fields.
|
||||
"""
|
||||
|
||||
version: int = 1
|
||||
is_active: bool = True
|
||||
name: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
recommended_schedule_cron: str | None = None
|
||||
nodes: list[Node] = []
|
||||
links: list[Link] = []
|
||||
forked_from_id: str | None = None
|
||||
forked_from_version: int | None = None
|
||||
|
||||
|
||||
class BaseGraph(GraphBaseMeta):
|
||||
"""
|
||||
Graph with nodes, links, and computed I/O schema fields.
|
||||
|
||||
Used to represent sub-graphs within a `Graph`. Contains the full graph
|
||||
structure including nodes and links, plus computed fields for schemas
|
||||
and trigger info. Does NOT include user_id or created_at (see GraphModel).
|
||||
"""
|
||||
|
||||
nodes: list[Node] = Field(default_factory=list)
|
||||
links: list[Link] = Field(default_factory=list)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
@@ -361,44 +376,79 @@ class GraphTriggerInfo(BaseModel):
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
||||
"""Creatable graph model used in API create/update endpoints."""
|
||||
|
||||
sub_graphs: list[BaseGraph] = Field(default_factory=list) # Flattened sub-graphs
|
||||
|
||||
|
||||
class GraphMeta(GraphBaseMeta):
|
||||
"""
|
||||
Lightweight graph metadata model representing an existing graph from the database,
|
||||
for use in listings and summaries.
|
||||
|
||||
Lacks `GraphModel`'s nodes, links, and expensive computed fields.
|
||||
Use for list endpoints where full graph data is not needed and performance matters.
|
||||
"""
|
||||
|
||||
id: str # type: ignore
|
||||
version: int # type: ignore
|
||||
user_id: str
|
||||
created_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, graph: "AgentGraph") -> Self:
|
||||
return cls(
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
is_active=graph.isActive,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
instructions=graph.instructions,
|
||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
||||
forked_from_id=graph.forkedFromId,
|
||||
forked_from_version=graph.forkedFromVersion,
|
||||
user_id=graph.userId,
|
||||
created_at=graph.createdAt,
|
||||
)
|
||||
|
||||
|
||||
class GraphModel(Graph, GraphMeta):
|
||||
"""
|
||||
Full graph model representing an existing graph from the database.
|
||||
|
||||
This is the primary model for working with persisted graphs. Includes all
|
||||
graph data (nodes, links, sub_graphs) plus user ownership and timestamps.
|
||||
Provides computed fields (input_schema, output_schema, etc.) used during
|
||||
set-up (frontend) and execution (backend).
|
||||
|
||||
Inherits from:
|
||||
- `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas
|
||||
- `GraphMeta`: provides user_id, created_at for database records
|
||||
"""
|
||||
|
||||
nodes: list[NodeModel] = Field(default_factory=list) # type: ignore
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[NodeModel]:
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
input_nodes = {
|
||||
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
||||
}
|
||||
return [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
||||
return cast(NodeModel, super().webhook_input_node)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
schema = self._credentials_input_schema.jsonschema()
|
||||
|
||||
# Determine which credential fields are required based on credentials_optional metadata
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
required_fields = []
|
||||
|
||||
# Build a map of node_id -> node for quick lookup
|
||||
all_nodes = {node.id: node for node in self.nodes}
|
||||
for sub_graph in self.sub_graphs:
|
||||
for node in sub_graph.nodes:
|
||||
all_nodes[node.id] = node
|
||||
|
||||
for field_key, (
|
||||
_field_info,
|
||||
node_field_pairs,
|
||||
) in graph_credentials_inputs.items():
|
||||
# A field is required if ANY node using it has credentials_optional=False
|
||||
is_required = False
|
||||
for node_id, _field_name in node_field_pairs:
|
||||
node = all_nodes.get(node_id)
|
||||
if node and not node.credentials_optional:
|
||||
is_required = True
|
||||
break
|
||||
|
||||
if is_required:
|
||||
required_fields.append(field_key)
|
||||
|
||||
schema["required"] = required_fields
|
||||
return schema
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
logger.debug(
|
||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||
f"{graph_credentials_inputs}"
|
||||
@@ -406,8 +456,8 @@ class Graph(BaseGraph):
|
||||
|
||||
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
||||
graph_cred_fields = list(graph_credentials_inputs.values())
|
||||
for i, (field, keys) in enumerate(graph_cred_fields):
|
||||
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||
for i, (field, keys, _) in enumerate(graph_cred_fields):
|
||||
for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]:
|
||||
if field.provider != other_field.provider:
|
||||
continue
|
||||
if ProviderName.HTTP in field.provider:
|
||||
@@ -423,31 +473,78 @@ class Graph(BaseGraph):
|
||||
f"keys: {keys} <> {other_keys}."
|
||||
)
|
||||
|
||||
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
|
||||
agg_field_key: (
|
||||
CredentialsMetaInput[
|
||||
Literal[tuple(field_info.provider)], # type: ignore
|
||||
Literal[tuple(field_info.supported_types)], # type: ignore
|
||||
],
|
||||
CredentialsField(
|
||||
required_scopes=set(field_info.required_scopes or []),
|
||||
discriminator=field_info.discriminator,
|
||||
discriminator_mapping=field_info.discriminator_mapping,
|
||||
discriminator_values=field_info.discriminator_values,
|
||||
),
|
||||
)
|
||||
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||
}
|
||||
# Build JSON schema directly to avoid expensive create_model + validation overhead
|
||||
properties = {}
|
||||
required_fields = []
|
||||
|
||||
return create_model(
|
||||
self.name.replace(" ", "") + "CredentialsInputSchema",
|
||||
__base__=BlockSchema,
|
||||
**fields, # type: ignore
|
||||
)
|
||||
for agg_field_key, (
|
||||
field_info,
|
||||
_,
|
||||
is_required,
|
||||
) in graph_credentials_inputs.items():
|
||||
providers = list(field_info.provider)
|
||||
cred_types = list(field_info.supported_types)
|
||||
|
||||
field_schema: dict[str, Any] = {
|
||||
"credentials_provider": providers,
|
||||
"credentials_types": cred_types,
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"title": "Id", "type": "string"},
|
||||
"title": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"default": None,
|
||||
"title": "Title",
|
||||
},
|
||||
"provider": {
|
||||
"title": "Provider",
|
||||
"type": "string",
|
||||
**(
|
||||
{"enum": providers}
|
||||
if len(providers) > 1
|
||||
else {"const": providers[0]}
|
||||
),
|
||||
},
|
||||
"type": {
|
||||
"title": "Type",
|
||||
"type": "string",
|
||||
**(
|
||||
{"enum": cred_types}
|
||||
if len(cred_types) > 1
|
||||
else {"const": cred_types[0]}
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["id", "provider", "type"],
|
||||
}
|
||||
|
||||
# Add other (optional) field info items
|
||||
field_schema.update(
|
||||
field_info.model_dump(
|
||||
by_alias=True,
|
||||
exclude_defaults=True,
|
||||
exclude={"provider", "supported_types"}, # already included above
|
||||
)
|
||||
)
|
||||
|
||||
# Ensure field schema is well-formed
|
||||
CredentialsMetaInput.validate_credentials_field_schema(
|
||||
field_schema, agg_field_key
|
||||
)
|
||||
|
||||
properties[agg_field_key] = field_schema
|
||||
if is_required:
|
||||
required_fields.append(agg_field_key)
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required_fields,
|
||||
}
|
||||
|
||||
def aggregate_credentials_inputs(
|
||||
self,
|
||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
|
||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
||||
"""
|
||||
Returns:
|
||||
dict[aggregated_field_key, tuple(
|
||||
@@ -455,13 +552,19 @@ class Graph(BaseGraph):
|
||||
(now includes discriminator_values from matching nodes)
|
||||
set[(node_id, field_name)]: Node credentials fields that are
|
||||
compatible with this aggregated field spec
|
||||
bool: True if the field is required (any node has credentials_optional=False)
|
||||
)]
|
||||
"""
|
||||
# First collect all credential field data with input defaults
|
||||
node_credential_data = []
|
||||
# Track (field_info, (node_id, field_name), is_required) for each credential field
|
||||
node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
|
||||
node_required_map: dict[str, bool] = {} # node_id -> is_required
|
||||
|
||||
for graph in [self] + self.sub_graphs:
|
||||
for node in graph.nodes:
|
||||
# Track if this node requires credentials (credentials_optional=False means required)
|
||||
node_required_map[node.id] = not node.credentials_optional
|
||||
|
||||
for (
|
||||
field_name,
|
||||
field_info,
|
||||
@@ -485,37 +588,21 @@ class Graph(BaseGraph):
|
||||
)
|
||||
|
||||
# Combine credential field info (this will merge discriminator_values automatically)
|
||||
return CredentialsFieldInfo.combine(*node_credential_data)
|
||||
combined = CredentialsFieldInfo.combine(*node_credential_data)
|
||||
|
||||
|
||||
class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
|
||||
created_at: datetime
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[NodeModel]:
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
input_nodes = {
|
||||
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
||||
# Add is_required flag to each aggregated field
|
||||
# A field is required if ANY node using it has credentials_optional=False
|
||||
return {
|
||||
key: (
|
||||
field_info,
|
||||
node_field_pairs,
|
||||
any(
|
||||
node_required_map.get(node_id, True)
|
||||
for node_id, _ in node_field_pairs
|
||||
),
|
||||
)
|
||||
for key, (field_info, node_field_pairs) in combined.items()
|
||||
}
|
||||
return [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
||||
return cast(NodeModel, super().webhook_input_node)
|
||||
|
||||
def meta(self) -> "GraphMeta":
|
||||
"""
|
||||
Returns a GraphMeta object with metadata about the graph.
|
||||
This is used to return metadata about the graph without exposing nodes and links.
|
||||
"""
|
||||
return GraphMeta.from_graph(self)
|
||||
|
||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||
"""
|
||||
@@ -799,13 +886,14 @@ class GraphModel(Graph):
|
||||
if is_static_output_block(link.source_id):
|
||||
link.is_static = True # Each value block output should be static.
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
@classmethod
|
||||
def from_db( # type: ignore[reportIncompatibleMethodOverride]
|
||||
cls,
|
||||
graph: AgentGraph,
|
||||
for_export: bool = False,
|
||||
sub_graphs: list[AgentGraph] | None = None,
|
||||
) -> "GraphModel":
|
||||
return GraphModel(
|
||||
) -> Self:
|
||||
return cls(
|
||||
id=graph.id,
|
||||
user_id=graph.userId if not for_export else "",
|
||||
version=graph.version,
|
||||
@@ -831,17 +919,28 @@ class GraphModel(Graph):
|
||||
],
|
||||
)
|
||||
|
||||
def hide_nodes(self) -> "GraphModelWithoutNodes":
|
||||
"""
|
||||
Returns a copy of the `GraphModel` with nodes, links, and sub-graphs hidden
|
||||
(excluded from serialization). They are still present in the model instance
|
||||
so all computed fields (e.g. `credentials_input_schema`) still work.
|
||||
"""
|
||||
return GraphModelWithoutNodes.model_validate(self, from_attributes=True)
|
||||
|
||||
class GraphMeta(Graph):
|
||||
user_id: str
|
||||
|
||||
# Easy work-around to prevent exposing nodes and links in the API response
|
||||
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
|
||||
links: list[Link] = Field(default=[], exclude=True)
|
||||
class GraphModelWithoutNodes(GraphModel):
|
||||
"""
|
||||
GraphModel variant that excludes nodes, links, and sub-graphs from serialization.
|
||||
|
||||
@staticmethod
|
||||
def from_graph(graph: GraphModel) -> "GraphMeta":
|
||||
return GraphMeta(**graph.model_dump())
|
||||
Used in contexts like the store where exposing internal graph structure
|
||||
is not desired. Inherits all computed fields from GraphModel but marks
|
||||
nodes and links as excluded from JSON output.
|
||||
"""
|
||||
|
||||
nodes: list[NodeModel] = Field(default_factory=list, exclude=True)
|
||||
links: list[Link] = Field(default_factory=list, exclude=True)
|
||||
|
||||
sub_graphs: list[BaseGraph] = Field(default_factory=list, exclude=True)
|
||||
|
||||
|
||||
class GraphsPaginated(BaseModel):
|
||||
@@ -912,21 +1011,11 @@ async def list_graphs_paginated(
|
||||
where=where_clause,
|
||||
distinct=["id"],
|
||||
order={"version": "desc"},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
skip=offset,
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
graph_models: list[GraphMeta] = []
|
||||
for graph in graphs:
|
||||
try:
|
||||
graph_meta = GraphModel.from_db(graph).meta()
|
||||
# Trigger serialization to validate that the graph is well formed
|
||||
graph_meta.model_dump()
|
||||
graph_models.append(graph_meta)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing graph {graph.id}: {e}")
|
||||
continue
|
||||
graph_models = [GraphMeta.from_db(graph) for graph in graphs]
|
||||
|
||||
return GraphsPaginated(
|
||||
graphs=graph_models,
|
||||
|
||||
@@ -19,7 +19,6 @@ from typing import (
|
||||
cast,
|
||||
get_args,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from prisma.enums import CreditTransactionType, OnboardingStep
|
||||
@@ -42,6 +41,7 @@ from typing_extensions import TypedDict
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import loads as json_loads
|
||||
from backend.util.request import parse_url
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
# Type alias for any provider name (including custom ones)
|
||||
@@ -163,7 +163,6 @@ class User(BaseModel):
|
||||
if TYPE_CHECKING:
|
||||
from prisma.models import User as PrismaUser
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
T = TypeVar("T")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -397,19 +396,25 @@ class HostScopedCredentials(_BaseCredentials):
|
||||
def matches_url(self, url: str) -> bool:
|
||||
"""Check if this credential should be applied to the given URL."""
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
# Extract hostname without port
|
||||
request_host = parsed_url.hostname
|
||||
request_host, request_port = _extract_host_from_url(url)
|
||||
cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
|
||||
if not request_host:
|
||||
return False
|
||||
|
||||
# Simple host matching - exact match or wildcard subdomain match
|
||||
if self.host == request_host:
|
||||
# If a port is specified in credential host, the request host port must match
|
||||
if cred_scope_port is not None and request_port != cred_scope_port:
|
||||
return False
|
||||
# Non-standard ports are only allowed if explicitly specified in credential host
|
||||
elif cred_scope_port is None and request_port not in (80, 443, None):
|
||||
return False
|
||||
|
||||
# Simple host matching
|
||||
if cred_scope_host == request_host:
|
||||
return True
|
||||
|
||||
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
||||
if self.host.startswith("*."):
|
||||
domain = self.host[2:] # Remove "*."
|
||||
if cred_scope_host.startswith("*."):
|
||||
domain = cred_scope_host[2:] # Remove "*."
|
||||
return request_host.endswith(f".{domain}") or request_host == domain
|
||||
|
||||
return False
|
||||
@@ -502,15 +507,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
||||
return get_args(cls.model_fields["type"].annotation)
|
||||
|
||||
@classmethod
|
||||
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
|
||||
@staticmethod
|
||||
def validate_credentials_field_schema(
|
||||
field_schema: dict[str, Any], field_name: str
|
||||
):
|
||||
"""Validates the schema of a credentials input field"""
|
||||
field_name = next(
|
||||
name for name, type in model.get_credentials_fields().items() if type is cls
|
||||
)
|
||||
field_schema = model.jsonschema()["properties"][field_name]
|
||||
try:
|
||||
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||
field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||
except ValidationError as e:
|
||||
if "Field required [type=missing" not in str(e):
|
||||
raise
|
||||
@@ -520,11 +523,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
f"{field_schema}"
|
||||
) from e
|
||||
|
||||
providers = cls.allowed_providers()
|
||||
providers = field_info.provider
|
||||
if (
|
||||
providers is not None
|
||||
and len(providers) > 1
|
||||
and not schema_extra.discriminator
|
||||
and not field_info.discriminator
|
||||
):
|
||||
raise TypeError(
|
||||
f"Multi-provider CredentialsField '{field_name}' "
|
||||
@@ -551,13 +554,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
|
||||
|
||||
def _extract_host_from_url(url: str) -> str:
|
||||
"""Extract host from URL for grouping host-scoped credentials."""
|
||||
def _extract_host_from_url(url: str) -> tuple[str, int | None]:
|
||||
"""Extract host and port from URL for grouping host-scoped credentials."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return parsed.hostname or url
|
||||
parsed = parse_url(url)
|
||||
return parsed.hostname or url, parsed.port
|
||||
except Exception:
|
||||
return ""
|
||||
return "", None
|
||||
|
||||
|
||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
@@ -606,7 +609,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
providers = frozenset(
|
||||
[cast(CP, "http")]
|
||||
+ [
|
||||
cast(CP, _extract_host_from_url(str(value)))
|
||||
cast(CP, parse_url(str(value)).netloc)
|
||||
for value in field.discriminator_values
|
||||
]
|
||||
)
|
||||
|
||||
@@ -79,10 +79,23 @@ class TestHostScopedCredentials:
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||
# Non-standard ports require explicit port in credential host
|
||||
assert not creds.matches_url("http://localhost:8080/api/v1")
|
||||
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
||||
assert creds.matches_url("http://localhost/simple")
|
||||
|
||||
def test_matches_url_with_explicit_port(self):
|
||||
"""Test URL matching with explicit port in credential host."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="localhost:8080",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||
assert not creds.matches_url("http://localhost:3000/api/v1")
|
||||
assert not creds.matches_url("http://localhost/simple")
|
||||
|
||||
def test_empty_headers_dict(self):
|
||||
"""Test HostScopedCredentials with empty headers."""
|
||||
creds = HostScopedCredentials(
|
||||
@@ -128,8 +141,20 @@ class TestHostScopedCredentials:
|
||||
("*.example.com", "https://sub.api.example.com/test", True),
|
||||
("*.example.com", "https://example.com/test", True),
|
||||
("*.example.com", "https://example.org/test", False),
|
||||
("localhost", "http://localhost:3000/test", True),
|
||||
# Non-standard ports require explicit port in credential host
|
||||
("localhost", "http://localhost:3000/test", False),
|
||||
("localhost:3000", "http://localhost:3000/test", True),
|
||||
("localhost", "http://127.0.0.1:3000/test", False),
|
||||
# IPv6 addresses (frontend stores with brackets via URL.hostname)
|
||||
("[::1]", "http://[::1]/test", True),
|
||||
("[::1]", "http://[::1]:80/test", True),
|
||||
("[::1]", "https://[::1]:443/test", True),
|
||||
("[::1]", "http://[::1]:8080/test", False), # Non-standard port
|
||||
("[::1]:8080", "http://[::1]:8080/test", True),
|
||||
("[::1]:8080", "http://[::1]:9090/test", False),
|
||||
("[2001:db8::1]", "http://[2001:db8::1]/path", True),
|
||||
("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
|
||||
("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
|
||||
],
|
||||
)
|
||||
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@@ -225,6 +226,10 @@ class SyncRabbitMQ(RabbitMQBase):
|
||||
class AsyncRabbitMQ(RabbitMQBase):
|
||||
"""Asynchronous RabbitMQ client"""
|
||||
|
||||
def __init__(self, config: RabbitMQConfig):
|
||||
super().__init__(config)
|
||||
self._reconnect_lock: asyncio.Lock | None = None
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return bool(self._connection and not self._connection.is_closed)
|
||||
@@ -235,7 +240,17 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
|
||||
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
||||
async def connect(self):
|
||||
if self.is_connected:
|
||||
if self.is_connected and self._channel and not self._channel.is_closed:
|
||||
return
|
||||
|
||||
if (
|
||||
self.is_connected
|
||||
and self._connection
|
||||
and (self._channel is None or self._channel.is_closed)
|
||||
):
|
||||
self._channel = await self._connection.channel()
|
||||
await self._channel.set_qos(prefetch_count=1)
|
||||
await self.declare_infrastructure()
|
||||
return
|
||||
|
||||
self._connection = await aio_pika.connect_robust(
|
||||
@@ -291,24 +306,46 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
exchange, routing_key=queue.routing_key or queue.name
|
||||
)
|
||||
|
||||
@func_retry
|
||||
async def publish_message(
|
||||
@property
|
||||
def _lock(self) -> asyncio.Lock:
|
||||
if self._reconnect_lock is None:
|
||||
self._reconnect_lock = asyncio.Lock()
|
||||
return self._reconnect_lock
|
||||
|
||||
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||
"""Get a valid channel, reconnecting if the current one is stale.
|
||||
|
||||
Uses a lock to prevent concurrent reconnection attempts from racing.
|
||||
"""
|
||||
if self.is_ready:
|
||||
return self._channel # type: ignore # is_ready guarantees non-None
|
||||
|
||||
async with self._lock:
|
||||
# Double-check after acquiring lock
|
||||
if self.is_ready:
|
||||
return self._channel # type: ignore
|
||||
|
||||
self._channel = None
|
||||
await self.connect()
|
||||
|
||||
if self._channel is None:
|
||||
raise RuntimeError("Channel should be established after connect")
|
||||
|
||||
return self._channel
|
||||
|
||||
async def _publish_once(
|
||||
self,
|
||||
routing_key: str,
|
||||
message: str,
|
||||
exchange: Optional[Exchange] = None,
|
||||
persistent: bool = True,
|
||||
) -> None:
|
||||
if not self.is_ready:
|
||||
await self.connect()
|
||||
|
||||
if self._channel is None:
|
||||
raise RuntimeError("Channel should be established after connect")
|
||||
channel = await self._ensure_channel()
|
||||
|
||||
if exchange:
|
||||
exchange_obj = await self._channel.get_exchange(exchange.name)
|
||||
exchange_obj = await channel.get_exchange(exchange.name)
|
||||
else:
|
||||
exchange_obj = self._channel.default_exchange
|
||||
exchange_obj = channel.default_exchange
|
||||
|
||||
await exchange_obj.publish(
|
||||
aio_pika.Message(
|
||||
@@ -322,9 +359,23 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
routing_key=routing_key,
|
||||
)
|
||||
|
||||
@func_retry
|
||||
async def publish_message(
|
||||
self,
|
||||
routing_key: str,
|
||||
message: str,
|
||||
exchange: Optional[Exchange] = None,
|
||||
persistent: bool = True,
|
||||
) -> None:
|
||||
try:
|
||||
await self._publish_once(routing_key, message, exchange, persistent)
|
||||
except aio_pika.exceptions.ChannelInvalidStateError:
|
||||
logger.warning(
|
||||
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
|
||||
)
|
||||
async with self._lock:
|
||||
self._channel = None
|
||||
await self._publish_once(routing_key, message, exchange, persistent)
|
||||
|
||||
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||
if not self.is_ready:
|
||||
await self.connect()
|
||||
if self._channel is None:
|
||||
raise RuntimeError("Channel should be established after connect")
|
||||
return self._channel
|
||||
return await self._ensure_channel()
|
||||
|
||||
@@ -17,6 +17,7 @@ from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
get_marketplace_graphs_for_monitoring,
|
||||
)
|
||||
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
create_graph_execution,
|
||||
@@ -219,6 +220,9 @@ class DatabaseManager(AppService):
|
||||
# Onboarding
|
||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||
|
||||
# OAuth
|
||||
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
@@ -349,6 +353,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
# Onboarding
|
||||
increment_onboarding_runs = d.increment_onboarding_runs
|
||||
|
||||
# OAuth
|
||||
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
|
||||
|
||||
# Store
|
||||
get_store_agents = d.get_store_agents
|
||||
get_store_agent_details = d.get_store_agent_details
|
||||
|
||||
@@ -24,11 +24,9 @@ from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.onboarding import increment_onboarding_runs
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.monitoring import (
|
||||
NotificationJobArgs,
|
||||
@@ -38,7 +36,11 @@ from backend.monitoring import (
|
||||
report_execution_accuracy_alerts,
|
||||
report_late_executions,
|
||||
)
|
||||
from backend.util.clients import get_database_manager_client, get_scheduler_client
|
||||
from backend.util.clients import (
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_scheduler_client,
|
||||
)
|
||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||
from backend.util.exceptions import (
|
||||
GraphNotFoundError,
|
||||
@@ -148,6 +150,7 @@ def execute_graph(**kwargs):
|
||||
async def _execute_graph(**kwargs):
|
||||
args = GraphExecutionJobArgs(**kwargs)
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
db = get_database_manager_async_client()
|
||||
try:
|
||||
logger.info(f"Executing recurring job for graph #{args.graph_id}")
|
||||
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
|
||||
@@ -157,7 +160,7 @@ async def _execute_graph(**kwargs):
|
||||
inputs=args.input_data,
|
||||
graph_credentials_inputs=args.input_credentials,
|
||||
)
|
||||
await increment_onboarding_runs(args.user_id)
|
||||
await db.increment_onboarding_runs(args.user_id)
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
logger.info(
|
||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
||||
@@ -246,8 +249,13 @@ def cleanup_expired_files():
|
||||
|
||||
def cleanup_oauth_tokens():
|
||||
"""Clean up expired OAuth tokens from the database."""
|
||||
|
||||
# Wait for completion
|
||||
run_async(cleanup_expired_oauth_tokens())
|
||||
async def _cleanup():
|
||||
db = get_database_manager_async_client()
|
||||
return await db.cleanup_expired_oauth_tokens()
|
||||
|
||||
run_async(_cleanup())
|
||||
|
||||
|
||||
def execution_accuracy_alerts():
|
||||
|
||||
@@ -373,7 +373,7 @@ def make_node_credentials_input_map(
|
||||
# Get aggregated credentials fields for the graph
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
|
||||
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
|
||||
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
||||
# Best-effort map: skip missing items
|
||||
if graph_input_name not in graph_credentials_input:
|
||||
continue
|
||||
|
||||
@@ -224,6 +224,14 @@ openweathermap_credentials = APIKeyCredentials(
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
elevenlabs_credentials = APIKeyCredentials(
|
||||
id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c",
|
||||
provider="elevenlabs",
|
||||
api_key=SecretStr(settings.secrets.elevenlabs_api_key),
|
||||
title="Use Credits for ElevenLabs",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
DEFAULT_CREDENTIALS = [
|
||||
ollama_credentials,
|
||||
revid_credentials,
|
||||
@@ -252,6 +260,7 @@ DEFAULT_CREDENTIALS = [
|
||||
v0_credentials,
|
||||
webshare_proxy_credentials,
|
||||
openweathermap_credentials,
|
||||
elevenlabs_credentials,
|
||||
]
|
||||
|
||||
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
||||
@@ -366,6 +375,8 @@ class IntegrationCredentialsStore:
|
||||
all_credentials.append(webshare_proxy_credentials)
|
||||
if settings.secrets.openweathermap_api_key:
|
||||
all_credentials.append(openweathermap_credentials)
|
||||
if settings.secrets.elevenlabs_api_key:
|
||||
all_credentials.append(elevenlabs_credentials)
|
||||
return all_credentials
|
||||
|
||||
async def get_creds_by_id(
|
||||
|
||||
@@ -18,6 +18,7 @@ class ProviderName(str, Enum):
|
||||
DISCORD = "discord"
|
||||
D_ID = "d_id"
|
||||
E2B = "e2b"
|
||||
ELEVENLABS = "elevenlabs"
|
||||
FAL = "fal"
|
||||
GITHUB = "github"
|
||||
GOOGLE = "google"
|
||||
|
||||
@@ -8,6 +8,8 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Config
|
||||
@@ -17,6 +19,35 @@ from backend.util.virus_scanner import scan_content_safe
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
|
||||
class WorkspaceUri(BaseModel):
|
||||
"""Parsed workspace:// URI."""
|
||||
|
||||
file_ref: str # File ID or path (e.g. "abc123" or "/path/to/file.txt")
|
||||
mime_type: str | None = None # MIME type from fragment (e.g. "video/mp4")
|
||||
is_path: bool = False # True if file_ref is a path (starts with "/")
|
||||
|
||||
|
||||
def parse_workspace_uri(uri: str) -> WorkspaceUri:
|
||||
"""Parse a workspace:// URI into its components.
|
||||
|
||||
Examples:
|
||||
"workspace://abc123" → WorkspaceUri(file_ref="abc123", mime_type=None, is_path=False)
|
||||
"workspace://abc123#video/mp4" → WorkspaceUri(file_ref="abc123", mime_type="video/mp4", is_path=False)
|
||||
"workspace:///path/to/file.txt" → WorkspaceUri(file_ref="/path/to/file.txt", mime_type=None, is_path=True)
|
||||
"""
|
||||
raw = uri.removeprefix("workspace://")
|
||||
mime_type: str | None = None
|
||||
if "#" in raw:
|
||||
raw, fragment = raw.split("#", 1)
|
||||
mime_type = fragment or None
|
||||
return WorkspaceUri(
|
||||
file_ref=raw,
|
||||
mime_type=mime_type,
|
||||
is_path=raw.startswith("/"),
|
||||
)
|
||||
|
||||
|
||||
# Return format options for store_media_file
|
||||
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
||||
@@ -183,22 +214,20 @@ async def store_media_file(
|
||||
"This file type is only available in CoPilot sessions."
|
||||
)
|
||||
|
||||
# Parse workspace reference
|
||||
# workspace://abc123 - by file ID
|
||||
# workspace:///path/to/file.txt - by virtual path
|
||||
file_ref = file[12:] # Remove "workspace://"
|
||||
# Parse workspace reference (strips #mimeType fragment from file ID)
|
||||
ws = parse_workspace_uri(file)
|
||||
|
||||
if file_ref.startswith("/"):
|
||||
# Path reference
|
||||
workspace_content = await workspace_manager.read_file(file_ref)
|
||||
file_info = await workspace_manager.get_file_info_by_path(file_ref)
|
||||
if ws.is_path:
|
||||
# Path reference: workspace:///path/to/file.txt
|
||||
workspace_content = await workspace_manager.read_file(ws.file_ref)
|
||||
file_info = await workspace_manager.get_file_info_by_path(ws.file_ref)
|
||||
filename = sanitize_filename(
|
||||
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||
)
|
||||
else:
|
||||
# ID reference
|
||||
workspace_content = await workspace_manager.read_file_by_id(file_ref)
|
||||
file_info = await workspace_manager.get_file_info(file_ref)
|
||||
# ID reference: workspace://abc123 or workspace://abc123#video/mp4
|
||||
workspace_content = await workspace_manager.read_file_by_id(ws.file_ref)
|
||||
file_info = await workspace_manager.get_file_info(ws.file_ref)
|
||||
filename = sanitize_filename(
|
||||
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||
)
|
||||
@@ -313,6 +342,14 @@ async def store_media_file(
|
||||
if not target_path.is_file():
|
||||
raise ValueError(f"Local file does not exist: {target_path}")
|
||||
|
||||
# Virus scan the local file before any further processing
|
||||
local_content = target_path.read_bytes()
|
||||
if len(local_content) > MAX_FILE_SIZE_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large: {len(local_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||
)
|
||||
await scan_content_safe(local_content, filename=sanitized_file)
|
||||
|
||||
# Return based on requested format
|
||||
if return_format == "for_local_processing":
|
||||
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
||||
@@ -334,7 +371,21 @@ async def store_media_file(
|
||||
|
||||
# Don't re-save if input was already from workspace
|
||||
if is_from_workspace:
|
||||
# Return original workspace reference
|
||||
# Return original workspace reference, ensuring MIME type fragment
|
||||
ws = parse_workspace_uri(file)
|
||||
if not ws.mime_type:
|
||||
# Add MIME type fragment if missing (older refs without it)
|
||||
try:
|
||||
if ws.is_path:
|
||||
info = await workspace_manager.get_file_info_by_path(
|
||||
ws.file_ref
|
||||
)
|
||||
else:
|
||||
info = await workspace_manager.get_file_info(ws.file_ref)
|
||||
if info:
|
||||
return MediaFileType(f"{file}#{info.mimeType}")
|
||||
except Exception:
|
||||
pass
|
||||
return MediaFileType(file)
|
||||
|
||||
# Save new content to workspace
|
||||
@@ -346,7 +397,7 @@ async def store_media_file(
|
||||
filename=filename,
|
||||
overwrite=True,
|
||||
)
|
||||
return MediaFileType(f"workspace://{file_record.id}")
|
||||
return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid return_format: {return_format}")
|
||||
|
||||
@@ -247,3 +247,100 @@ class TestFileCloudIntegration:
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_media_file_local_path_scanned(self):
|
||||
"""Test that local file paths are scanned for viruses."""
|
||||
graph_exec_id = "test-exec-123"
|
||||
local_file = "test_video.mp4"
|
||||
file_content = b"fake video content"
|
||||
|
||||
with patch(
|
||||
"backend.util.file.get_cloud_storage_handler"
|
||||
) as mock_handler_getter, patch(
|
||||
"backend.util.file.scan_content_safe"
|
||||
) as mock_scan, patch(
|
||||
"backend.util.file.Path"
|
||||
) as mock_path_class:
|
||||
|
||||
# Mock cloud storage handler - not a cloud path
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.is_cloud_path.return_value = False
|
||||
mock_handler_getter.return_value = mock_handler
|
||||
|
||||
# Mock virus scanner
|
||||
mock_scan.return_value = None
|
||||
|
||||
# Mock file system operations
|
||||
mock_base_path = MagicMock()
|
||||
mock_target_path = MagicMock()
|
||||
mock_resolved_path = MagicMock()
|
||||
|
||||
mock_path_class.return_value = mock_base_path
|
||||
mock_base_path.mkdir = MagicMock()
|
||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
||||
mock_target_path.resolve.return_value = mock_resolved_path
|
||||
mock_resolved_path.is_relative_to.return_value = True
|
||||
mock_resolved_path.is_file.return_value = True
|
||||
mock_resolved_path.read_bytes.return_value = file_content
|
||||
mock_resolved_path.relative_to.return_value = Path(local_file)
|
||||
mock_resolved_path.name = local_file
|
||||
|
||||
result = await store_media_file(
|
||||
file=MediaFileType(local_file),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
# Verify virus scan was called for local file
|
||||
mock_scan.assert_called_once_with(file_content, filename=local_file)
|
||||
|
||||
# Result should be the relative path
|
||||
assert str(result) == local_file
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_media_file_local_path_virus_detected(self):
|
||||
"""Test that infected local files raise VirusDetectedError."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
graph_exec_id = "test-exec-123"
|
||||
local_file = "infected.exe"
|
||||
file_content = b"malicious content"
|
||||
|
||||
with patch(
|
||||
"backend.util.file.get_cloud_storage_handler"
|
||||
) as mock_handler_getter, patch(
|
||||
"backend.util.file.scan_content_safe"
|
||||
) as mock_scan, patch(
|
||||
"backend.util.file.Path"
|
||||
) as mock_path_class:
|
||||
|
||||
# Mock cloud storage handler - not a cloud path
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.is_cloud_path.return_value = False
|
||||
mock_handler_getter.return_value = mock_handler
|
||||
|
||||
# Mock virus scanner to detect virus
|
||||
mock_scan.side_effect = VirusDetectedError(
|
||||
"EICAR-Test-File", "File rejected due to virus detection"
|
||||
)
|
||||
|
||||
# Mock file system operations
|
||||
mock_base_path = MagicMock()
|
||||
mock_target_path = MagicMock()
|
||||
mock_resolved_path = MagicMock()
|
||||
|
||||
mock_path_class.return_value = mock_base_path
|
||||
mock_base_path.mkdir = MagicMock()
|
||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
||||
mock_target_path.resolve.return_value = mock_resolved_path
|
||||
mock_resolved_path.is_relative_to.return_value = True
|
||||
mock_resolved_path.is_file.return_value = True
|
||||
mock_resolved_path.read_bytes.return_value = file_content
|
||||
|
||||
with pytest.raises(VirusDetectedError):
|
||||
await store_media_file(
|
||||
file=MediaFileType(local_file),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.util import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------#
|
||||
# CONSTANTS #
|
||||
# ---------------------------------------------------------------------------#
|
||||
@@ -100,9 +109,17 @@ def _is_objective_message(msg: dict) -> bool:
|
||||
def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
|
||||
"""
|
||||
Carefully truncate tool message content while preserving tool structure.
|
||||
Only truncates tool_result content, leaves tool_use intact.
|
||||
Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages.
|
||||
"""
|
||||
content = msg.get("content")
|
||||
|
||||
# OpenAI-style tool message: role="tool" with string content
|
||||
if msg.get("role") == "tool" and isinstance(content, str):
|
||||
if _tok_len(content, enc) > max_tokens:
|
||||
msg["content"] = _truncate_middle_tokens(content, enc, max_tokens)
|
||||
return
|
||||
|
||||
# Anthropic-style: list content with tool_result items
|
||||
if not isinstance(content, list):
|
||||
return
|
||||
|
||||
@@ -140,141 +157,6 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
# ---------------------------------------------------------------------------#
|
||||
|
||||
|
||||
def compress_prompt(
|
||||
messages: list[dict],
|
||||
target_tokens: int,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
reserve: int = 2_048,
|
||||
start_cap: int = 8_192,
|
||||
floor_cap: int = 128,
|
||||
lossy_ok: bool = True,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Shrink *messages* so that::
|
||||
|
||||
token_count(prompt) + reserve ≤ target_tokens
|
||||
|
||||
Strategy
|
||||
--------
|
||||
1. **Token-aware truncation** – progressively halve a per-message cap
|
||||
(`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
|
||||
*content* of every message except the first and last. Tool shells
|
||||
are included: we keep the envelope but shorten huge payloads.
|
||||
2. **Middle-out deletion** – if still over the limit, delete whole
|
||||
messages working outward from the centre, **skipping** any message
|
||||
that contains ``tool_calls`` or has ``role == "tool"``.
|
||||
3. **Last-chance trim** – if still too big, truncate the *first* and
|
||||
*last* message bodies down to `floor_cap` tokens.
|
||||
4. If the prompt is *still* too large:
|
||||
• raise ``ValueError`` when ``lossy_ok == False`` (default)
|
||||
• return the partially-trimmed prompt when ``lossy_ok == True``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
messages Complete chat history (will be deep-copied).
|
||||
model Model name; passed to tiktoken to pick the right
|
||||
tokenizer (gpt-4o → 'o200k_base', others fallback).
|
||||
target_tokens Hard ceiling for prompt size **excluding** the model's
|
||||
forthcoming answer.
|
||||
reserve How many tokens you want to leave available for that
|
||||
answer (`max_tokens` in your subsequent completion call).
|
||||
start_cap Initial per-message truncation ceiling (tokens).
|
||||
floor_cap Lowest cap we'll accept before moving to deletions.
|
||||
lossy_ok If *True* return best-effort prompt instead of raising
|
||||
after all trim passes have been exhausted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[dict] – A *new* messages list that abides by the rules above.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
msgs = deepcopy(messages) # never mutate caller
|
||||
|
||||
def total_tokens() -> int:
|
||||
"""Current size of *msgs* in tokens."""
|
||||
return sum(_msg_tokens(m, enc) for m in msgs)
|
||||
|
||||
original_token_count = total_tokens()
|
||||
|
||||
if original_token_count + reserve <= target_tokens:
|
||||
return msgs
|
||||
|
||||
# ---- STEP 0 : normalise content --------------------------------------
|
||||
# Convert non-string payloads to strings so token counting is coherent.
|
||||
for i, m in enumerate(msgs):
|
||||
if not isinstance(m.get("content"), str) and m.get("content") is not None:
|
||||
if _is_tool_message(m):
|
||||
continue
|
||||
|
||||
# Keep first and last messages intact (unless they're tool messages)
|
||||
if i == 0 or i == len(msgs) - 1:
|
||||
continue
|
||||
|
||||
# Reasonable 20k-char ceiling prevents pathological blobs
|
||||
content_str = json.dumps(m["content"], separators=(",", ":"))
|
||||
if len(content_str) > 20_000:
|
||||
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
|
||||
m["content"] = content_str
|
||||
|
||||
# ---- STEP 1 : token-aware truncation ---------------------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for m in msgs[1:-1]: # keep first & last intact
|
||||
if _is_tool_message(m):
|
||||
# For tool messages, only truncate tool result content, preserve structure
|
||||
_truncate_tool_message_content(m, enc, cap)
|
||||
continue
|
||||
|
||||
if _is_objective_message(m):
|
||||
# Never truncate objective messages - they contain the core task
|
||||
continue
|
||||
|
||||
content = m.get("content") or ""
|
||||
if _tok_len(content, enc) > cap:
|
||||
m["content"] = _truncate_middle_tokens(content, enc, cap)
|
||||
cap //= 2 # tighten the screw
|
||||
|
||||
# ---- STEP 2 : middle-out deletion -----------------------------------
|
||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||
# Identify all deletable messages (not first/last, not tool messages, not objective messages)
|
||||
deletable_indices = []
|
||||
for i in range(1, len(msgs) - 1): # Skip first and last
|
||||
if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
|
||||
deletable_indices.append(i)
|
||||
|
||||
if not deletable_indices:
|
||||
break # nothing more we can drop
|
||||
|
||||
# Delete from center outward - find the index closest to center
|
||||
centre = len(msgs) // 2
|
||||
to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
|
||||
del msgs[to_delete]
|
||||
|
||||
# ---- STEP 3 : final safety-net trim on first & last ------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for idx in (0, -1): # first and last
|
||||
if _is_tool_message(msgs[idx]):
|
||||
# For tool messages at first/last position, truncate tool result content only
|
||||
_truncate_tool_message_content(msgs[idx], enc, cap)
|
||||
continue
|
||||
|
||||
text = msgs[idx].get("content") or ""
|
||||
if _tok_len(text, enc) > cap:
|
||||
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
|
||||
cap //= 2 # tighten the screw
|
||||
|
||||
# ---- STEP 4 : success or fail-gracefully -----------------------------
|
||||
if total_tokens() + reserve > target_tokens and not lossy_ok:
|
||||
raise ValueError(
|
||||
"compress_prompt: prompt still exceeds budget "
|
||||
f"({total_tokens() + reserve} > {target_tokens})."
|
||||
)
|
||||
|
||||
return msgs
|
||||
|
||||
|
||||
def estimate_token_count(
|
||||
messages: list[dict],
|
||||
*,
|
||||
@@ -293,7 +175,8 @@ def estimate_token_count(
|
||||
-------
|
||||
int – Token count.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
token_model = _normalize_model_for_tokenizer(model)
|
||||
enc = encoding_for_model(token_model)
|
||||
return sum(_msg_tokens(m, enc) for m in messages)
|
||||
|
||||
|
||||
@@ -315,6 +198,543 @@ def estimate_token_count_str(
|
||||
-------
|
||||
int – Token count.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
token_model = _normalize_model_for_tokenizer(model)
|
||||
enc = encoding_for_model(token_model)
|
||||
text = json.dumps(text) if not isinstance(text, str) else text
|
||||
return _tok_len(text, enc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------#
|
||||
# UNIFIED CONTEXT COMPRESSION #
|
||||
# ---------------------------------------------------------------------------#
|
||||
|
||||
# Default thresholds
|
||||
DEFAULT_TOKEN_THRESHOLD = 120_000
|
||||
DEFAULT_KEEP_RECENT = 15
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressResult:
|
||||
"""Result of context compression."""
|
||||
|
||||
messages: list[dict]
|
||||
token_count: int
|
||||
was_compacted: bool
|
||||
error: str | None = None
|
||||
original_token_count: int = 0
|
||||
messages_summarized: int = 0
|
||||
messages_dropped: int = 0
|
||||
|
||||
|
||||
def _normalize_model_for_tokenizer(model: str) -> str:
|
||||
"""Normalize model name for tiktoken tokenizer selection."""
|
||||
if "/" in model:
|
||||
model = model.split("/")[-1]
|
||||
if "claude" in model.lower() or not any(
|
||||
known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"]
|
||||
):
|
||||
return "gpt-4o"
|
||||
return model
|
||||
|
||||
|
||||
def _extract_tool_call_ids_from_message(msg: dict) -> set[str]:
|
||||
"""
|
||||
Extract tool_call IDs from an assistant message.
|
||||
|
||||
Supports both formats:
|
||||
- OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]}
|
||||
- Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]}
|
||||
|
||||
Returns:
|
||||
Set of tool_call IDs found in the message.
|
||||
"""
|
||||
ids: set[str] = set()
|
||||
if msg.get("role") != "assistant":
|
||||
return ids
|
||||
|
||||
# OpenAI format: tool_calls array
|
||||
if msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
# Anthropic format: content list with tool_use blocks
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
tc_id = block.get("id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
|
||||
"""
|
||||
Extract tool_call IDs that this message is responding to.
|
||||
|
||||
Supports both formats:
|
||||
- OpenAI: {"role": "tool", "tool_call_id": "..."}
|
||||
- Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]}
|
||||
|
||||
Returns:
|
||||
Set of tool_call IDs this message responds to.
|
||||
"""
|
||||
ids: set[str] = set()
|
||||
|
||||
# OpenAI format: role=tool with tool_call_id
|
||||
if msg.get("role") == "tool":
|
||||
tc_id = msg.get("tool_call_id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
# Anthropic format: content list with tool_result blocks
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
tc_id = block.get("tool_use_id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
def _is_tool_response_message(msg: dict) -> bool:
|
||||
"""Check if message is a tool response (OpenAI or Anthropic format)."""
|
||||
# OpenAI format
|
||||
if msg.get("role") == "tool":
|
||||
return True
|
||||
# Anthropic format
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _remove_orphan_tool_responses(
|
||||
messages: list[dict], orphan_ids: set[str]
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Remove tool response messages/blocks that reference orphan tool_call IDs.
|
||||
|
||||
Supports both OpenAI and Anthropic formats.
|
||||
For Anthropic messages with mixed valid/orphan tool_result blocks,
|
||||
filters out only the orphan blocks instead of dropping the entire message.
|
||||
"""
|
||||
result = []
|
||||
for msg in messages:
|
||||
# OpenAI format: role=tool - drop entire message if orphan
|
||||
if msg.get("role") == "tool":
|
||||
tc_id = msg.get("tool_call_id")
|
||||
if tc_id and tc_id in orphan_ids:
|
||||
continue
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
# Anthropic format: content list may have mixed tool_result blocks
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
has_tool_results = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result" for b in content
|
||||
)
|
||||
if has_tool_results:
|
||||
# Filter out orphan tool_result blocks, keep valid ones
|
||||
filtered_content = [
|
||||
block
|
||||
for block in content
|
||||
if not (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "tool_result"
|
||||
and block.get("tool_use_id") in orphan_ids
|
||||
)
|
||||
]
|
||||
# Only keep message if it has remaining content
|
||||
if filtered_content:
|
||||
msg = msg.copy()
|
||||
msg["content"] = filtered_content
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_tool_pairs_intact(
|
||||
recent_messages: list[dict],
|
||||
all_messages: list[dict],
|
||||
start_index: int,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Ensure tool_call/tool_response pairs stay together after slicing.
|
||||
|
||||
When slicing messages for context compaction, a naive slice can separate
|
||||
an assistant message containing tool_calls from its corresponding tool
|
||||
response messages. This causes API validation errors (e.g., Anthropic's
|
||||
"unexpected tool_use_id found in tool_result blocks").
|
||||
|
||||
This function checks for orphan tool responses in the slice and extends
|
||||
backwards to include their corresponding assistant messages.
|
||||
|
||||
Supports both formats:
|
||||
- OpenAI: tool_calls array + role="tool" responses
|
||||
- Anthropic: tool_use blocks + tool_result blocks
|
||||
|
||||
Args:
|
||||
recent_messages: The sliced messages to validate
|
||||
all_messages: The complete message list (for looking up missing assistants)
|
||||
start_index: The index in all_messages where recent_messages begins
|
||||
|
||||
Returns:
|
||||
A potentially extended list of messages with tool pairs intact
|
||||
"""
|
||||
if not recent_messages:
|
||||
return recent_messages
|
||||
|
||||
# Collect all tool_call_ids from assistant messages in the slice
|
||||
available_tool_call_ids: set[str] = set()
|
||||
for msg in recent_messages:
|
||||
available_tool_call_ids |= _extract_tool_call_ids_from_message(msg)
|
||||
|
||||
# Find orphan tool responses (responses whose tool_call_id is missing)
|
||||
orphan_tool_call_ids: set[str] = set()
|
||||
for msg in recent_messages:
|
||||
response_ids = _extract_tool_response_ids_from_message(msg)
|
||||
for tc_id in response_ids:
|
||||
if tc_id not in available_tool_call_ids:
|
||||
orphan_tool_call_ids.add(tc_id)
|
||||
|
||||
if not orphan_tool_call_ids:
|
||||
# No orphans, slice is valid
|
||||
return recent_messages
|
||||
|
||||
# Find the assistant messages that contain the orphan tool_call_ids
|
||||
# Search backwards from start_index in all_messages
|
||||
messages_to_prepend: list[dict] = []
|
||||
for i in range(start_index - 1, -1, -1):
|
||||
msg = all_messages[i]
|
||||
msg_tool_ids = _extract_tool_call_ids_from_message(msg)
|
||||
if msg_tool_ids & orphan_tool_call_ids:
|
||||
# This assistant message has tool_calls we need
|
||||
# Also collect its contiguous tool responses that follow it
|
||||
assistant_and_responses: list[dict] = [msg]
|
||||
|
||||
# Scan forward from this assistant to collect tool responses
|
||||
for j in range(i + 1, start_index):
|
||||
following_msg = all_messages[j]
|
||||
following_response_ids = _extract_tool_response_ids_from_message(
|
||||
following_msg
|
||||
)
|
||||
if following_response_ids and following_response_ids & msg_tool_ids:
|
||||
assistant_and_responses.append(following_msg)
|
||||
elif not _is_tool_response_message(following_msg):
|
||||
# Stop at first non-tool-response message
|
||||
break
|
||||
|
||||
# Prepend the assistant and its tool responses (maintain order)
|
||||
messages_to_prepend = assistant_and_responses + messages_to_prepend
|
||||
# Mark these as found
|
||||
orphan_tool_call_ids -= msg_tool_ids
|
||||
# Also add this assistant's tool_call_ids to available set
|
||||
available_tool_call_ids |= msg_tool_ids
|
||||
|
||||
if not orphan_tool_call_ids:
|
||||
# Found all missing assistants
|
||||
break
|
||||
|
||||
if orphan_tool_call_ids:
|
||||
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
||||
# This shouldn't happen in normal operation but handles edge cases
|
||||
logger.warning(
|
||||
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
||||
"Removing orphan tool responses."
|
||||
)
|
||||
recent_messages = _remove_orphan_tool_responses(
|
||||
recent_messages, orphan_tool_call_ids
|
||||
)
|
||||
|
||||
if messages_to_prepend:
|
||||
logger.info(
|
||||
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
||||
f"tool_call/tool_response pairs"
|
||||
)
|
||||
return messages_to_prepend + recent_messages
|
||||
|
||||
return recent_messages
|
||||
|
||||
|
||||
async def _summarize_messages_llm(
|
||||
messages: list[dict],
|
||||
client: AsyncOpenAI,
|
||||
model: str,
|
||||
timeout: float = 30.0,
|
||||
) -> str:
|
||||
"""Summarize messages using an LLM."""
|
||||
conversation = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
if content and role in ("user", "assistant", "tool"):
|
||||
conversation.append(f"{role.upper()}: {content}")
|
||||
|
||||
conversation_text = "\n\n".join(conversation)
|
||||
|
||||
if not conversation_text:
|
||||
return "No conversation history available."
|
||||
|
||||
# Limit to ~100k chars for safety
|
||||
MAX_CHARS = 100_000
|
||||
if len(conversation_text) > MAX_CHARS:
|
||||
conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
|
||||
|
||||
response = await client.with_options(timeout=timeout).chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"Create a detailed summary of the conversation so far. "
|
||||
"This summary will be used as context when continuing the conversation.\n\n"
|
||||
"Before writing the summary, analyze each message chronologically to identify:\n"
|
||||
"- User requests and their explicit goals\n"
|
||||
"- Your approach and key decisions made\n"
|
||||
"- Technical specifics (file names, tool outputs, function signatures)\n"
|
||||
"- Errors encountered and resolutions applied\n\n"
|
||||
"You MUST include ALL of the following sections:\n\n"
|
||||
"## 1. Primary Request and Intent\n"
|
||||
"The user's explicit goals and what they are trying to accomplish.\n\n"
|
||||
"## 2. Key Technical Concepts\n"
|
||||
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
|
||||
"## 3. Files and Resources Involved\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
|
||||
"## 4. Errors and Fixes\n"
|
||||
"Problems encountered, error messages, and their resolutions. "
|
||||
"Include any user feedback on fixes.\n\n"
|
||||
"## 5. Problem Solving\n"
|
||||
"Issues that have been resolved and how they were addressed.\n\n"
|
||||
"## 6. All User Messages\n"
|
||||
"A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
|
||||
"## 7. Pending Tasks\n"
|
||||
"Work items the user explicitly requested that have not yet been completed.\n\n"
|
||||
"## 8. Current Work\n"
|
||||
"Precise description of what was being worked on most recently, including relevant context.\n\n"
|
||||
"## 9. Next Steps\n"
|
||||
"What should happen next, aligned with the user's most recent requests. "
|
||||
"Include verbatim quotes of recent instructions if relevant."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
|
||||
],
|
||||
max_tokens=1500,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content or "No summary available."
|
||||
|
||||
|
||||
async def compress_context(
|
||||
messages: list[dict],
|
||||
target_tokens: int = DEFAULT_TOKEN_THRESHOLD,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
client: AsyncOpenAI | None = None,
|
||||
keep_recent: int = DEFAULT_KEEP_RECENT,
|
||||
reserve: int = 2_048,
|
||||
start_cap: int = 8_192,
|
||||
floor_cap: int = 128,
|
||||
) -> CompressResult:
|
||||
"""
|
||||
Unified context compression that combines summarization and truncation strategies.
|
||||
|
||||
Strategy (in order):
|
||||
1. **LLM summarization** – If client provided, summarize old messages into a
|
||||
single context message while keeping recent messages intact. This is the
|
||||
primary strategy for chat service.
|
||||
2. **Content truncation** – Progressively halve a per-message cap and truncate
|
||||
bloated message content (tool outputs, large pastes). Preserves all messages
|
||||
but shortens their content. Primary strategy when client=None (LLM blocks).
|
||||
3. **Middle-out deletion** – Delete whole messages one at a time from the center
|
||||
outward, skipping tool messages and objective messages.
|
||||
4. **First/last trim** – Truncate first and last message content as last resort.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
messages Complete chat history (will be deep-copied).
|
||||
target_tokens Hard ceiling for prompt size.
|
||||
model Model name for tokenization and summarization.
|
||||
client AsyncOpenAI client. If provided, enables LLM summarization
|
||||
as the first strategy. If None, skips to truncation strategies.
|
||||
keep_recent Number of recent messages to preserve during summarization.
|
||||
reserve Tokens to reserve for model response.
|
||||
start_cap Initial per-message truncation ceiling (tokens).
|
||||
floor_cap Lowest cap before moving to deletions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
CompressResult with compressed messages and metadata.
|
||||
"""
|
||||
# Guard clause for empty messages
|
||||
if not messages:
|
||||
return CompressResult(
|
||||
messages=[],
|
||||
token_count=0,
|
||||
was_compacted=False,
|
||||
original_token_count=0,
|
||||
)
|
||||
|
||||
token_model = _normalize_model_for_tokenizer(model)
|
||||
enc = encoding_for_model(token_model)
|
||||
msgs = deepcopy(messages)
|
||||
|
||||
def total_tokens() -> int:
|
||||
return sum(_msg_tokens(m, enc) for m in msgs)
|
||||
|
||||
original_count = total_tokens()
|
||||
|
||||
# Already under limit
|
||||
if original_count + reserve <= target_tokens:
|
||||
return CompressResult(
|
||||
messages=msgs,
|
||||
token_count=original_count,
|
||||
was_compacted=False,
|
||||
original_token_count=original_count,
|
||||
)
|
||||
|
||||
messages_summarized = 0
|
||||
messages_dropped = 0
|
||||
|
||||
# ---- STEP 1: LLM summarization (if client provided) -------------------
|
||||
# This is the primary compression strategy for chat service.
|
||||
# Summarize old messages while keeping recent ones intact.
|
||||
if client is not None:
|
||||
has_system = len(msgs) > 0 and msgs[0].get("role") == "system"
|
||||
system_msg = msgs[0] if has_system else None
|
||||
|
||||
# Calculate old vs recent messages
|
||||
if has_system:
|
||||
if len(msgs) > keep_recent + 1:
|
||||
old_msgs = msgs[1:-keep_recent]
|
||||
recent_msgs = msgs[-keep_recent:]
|
||||
else:
|
||||
old_msgs = []
|
||||
recent_msgs = msgs[1:] if len(msgs) > 1 else []
|
||||
else:
|
||||
if len(msgs) > keep_recent:
|
||||
old_msgs = msgs[:-keep_recent]
|
||||
recent_msgs = msgs[-keep_recent:]
|
||||
else:
|
||||
old_msgs = []
|
||||
recent_msgs = msgs
|
||||
|
||||
# Ensure tool pairs stay intact
|
||||
slice_start = max(0, len(msgs) - keep_recent)
|
||||
recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start)
|
||||
|
||||
if old_msgs:
|
||||
try:
|
||||
summary_text = await _summarize_messages_llm(old_msgs, client, model)
|
||||
summary_msg = {
|
||||
"role": "assistant",
|
||||
"content": f"[Previous conversation summary — for context only]: {summary_text}",
|
||||
}
|
||||
messages_summarized = len(old_msgs)
|
||||
|
||||
if has_system:
|
||||
msgs = [system_msg, summary_msg] + recent_msgs
|
||||
else:
|
||||
msgs = [summary_msg] + recent_msgs
|
||||
|
||||
logger.info(
|
||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
||||
f"summarized {messages_summarized} messages"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
||||
# Fall through to content truncation
|
||||
|
||||
# ---- STEP 2: Normalize content ----------------------------------------
|
||||
# Convert non-string payloads to strings so token counting is coherent.
|
||||
# Always run this before truncation to ensure consistent token counting.
|
||||
for i, m in enumerate(msgs):
|
||||
if not isinstance(m.get("content"), str) and m.get("content") is not None:
|
||||
if _is_tool_message(m):
|
||||
continue
|
||||
if i == 0 or i == len(msgs) - 1:
|
||||
continue
|
||||
content_str = json.dumps(m["content"], separators=(",", ":"))
|
||||
if len(content_str) > 20_000:
|
||||
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
|
||||
m["content"] = content_str
|
||||
|
||||
# ---- STEP 3: Token-aware content truncation ---------------------------
|
||||
# Progressively halve per-message cap and truncate bloated content.
|
||||
# This preserves all messages but shortens their content.
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for m in msgs[1:-1]:
|
||||
if _is_tool_message(m):
|
||||
_truncate_tool_message_content(m, enc, cap)
|
||||
continue
|
||||
if _is_objective_message(m):
|
||||
continue
|
||||
content = m.get("content") or ""
|
||||
if _tok_len(content, enc) > cap:
|
||||
m["content"] = _truncate_middle_tokens(content, enc, cap)
|
||||
cap //= 2
|
||||
|
||||
# ---- STEP 4: Middle-out deletion --------------------------------------
|
||||
# Delete messages one at a time from the center outward.
|
||||
# This is more granular than dropping all old messages at once.
|
||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||
deletable: list[int] = []
|
||||
for i in range(1, len(msgs) - 1):
|
||||
msg = msgs[i]
|
||||
if (
|
||||
msg is not None
|
||||
and not _is_tool_message(msg)
|
||||
and not _is_objective_message(msg)
|
||||
):
|
||||
deletable.append(i)
|
||||
if not deletable:
|
||||
break
|
||||
centre = len(msgs) // 2
|
||||
to_delete = min(deletable, key=lambda i: abs(i - centre))
|
||||
del msgs[to_delete]
|
||||
messages_dropped += 1
|
||||
|
||||
# ---- STEP 5: Final trim on first/last ---------------------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for idx in (0, -1):
|
||||
msg = msgs[idx]
|
||||
if msg is None:
|
||||
continue
|
||||
if _is_tool_message(msg):
|
||||
_truncate_tool_message_content(msg, enc, cap)
|
||||
continue
|
||||
text = msg.get("content") or ""
|
||||
if _tok_len(text, enc) > cap:
|
||||
msg["content"] = _truncate_middle_tokens(text, enc, cap)
|
||||
cap //= 2
|
||||
|
||||
# Filter out any None values that may have been introduced
|
||||
final_msgs: list[dict] = [m for m in msgs if m is not None]
|
||||
final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
|
||||
error = None
|
||||
if final_count + reserve > target_tokens:
|
||||
error = f"Could not compress below target ({final_count + reserve} > {target_tokens})"
|
||||
logger.warning(error)
|
||||
|
||||
return CompressResult(
|
||||
messages=final_msgs,
|
||||
token_count=final_count,
|
||||
was_compacted=True,
|
||||
error=error,
|
||||
original_token_count=original_count,
|
||||
messages_summarized=messages_summarized,
|
||||
messages_dropped=messages_dropped,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
"""Tests for prompt utility functions, especially tool call token counting."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import _msg_tokens, estimate_token_count
|
||||
from backend.util.prompt import (
|
||||
CompressResult,
|
||||
_ensure_tool_pairs_intact,
|
||||
_msg_tokens,
|
||||
_normalize_model_for_tokenizer,
|
||||
_truncate_middle_tokens,
|
||||
_truncate_tool_message_content,
|
||||
compress_context,
|
||||
estimate_token_count,
|
||||
)
|
||||
|
||||
|
||||
class TestMsgTokens:
|
||||
@@ -276,3 +287,690 @@ class TestEstimateTokenCount:
|
||||
|
||||
assert total_tokens == expected_total
|
||||
assert total_tokens > 20 # Should be substantial
|
||||
|
||||
|
||||
class TestNormalizeModelForTokenizer:
|
||||
"""Test model name normalization for tiktoken."""
|
||||
|
||||
def test_openai_models_unchanged(self):
|
||||
"""Test that OpenAI models are returned as-is."""
|
||||
assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4"
|
||||
assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo"
|
||||
|
||||
def test_claude_models_normalized(self):
|
||||
"""Test that Claude models are normalized to gpt-4o."""
|
||||
assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o"
|
||||
|
||||
def test_openrouter_paths_extracted(self):
|
||||
"""Test that OpenRouter model paths are handled."""
|
||||
assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o"
|
||||
|
||||
def test_unknown_models_default_to_gpt4o(self):
|
||||
"""Test that unknown models default to gpt-4o."""
|
||||
assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o"
|
||||
|
||||
|
||||
class TestTruncateToolMessageContent:
|
||||
"""Test tool message content truncation."""
|
||||
|
||||
@pytest.fixture
|
||||
def enc(self):
|
||||
return encoding_for_model("gpt-4o")
|
||||
|
||||
def test_truncate_openai_tool_message(self, enc):
|
||||
"""Test truncation of OpenAI-style tool message with string content."""
|
||||
long_content = "x" * 10000
|
||||
msg = {"role": "tool", "tool_call_id": "call_123", "content": long_content}
|
||||
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=100)
|
||||
|
||||
# Content should be truncated
|
||||
assert len(msg["content"]) < len(long_content)
|
||||
assert "…" in msg["content"] # Has ellipsis marker
|
||||
|
||||
def test_truncate_anthropic_tool_result(self, enc):
|
||||
"""Test truncation of Anthropic-style tool_result."""
|
||||
long_content = "y" * 10000
|
||||
msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_123",
|
||||
"content": long_content,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=100)
|
||||
|
||||
# Content should be truncated
|
||||
result_content = msg["content"][0]["content"]
|
||||
assert len(result_content) < len(long_content)
|
||||
assert "…" in result_content
|
||||
|
||||
def test_preserve_tool_use_blocks(self, enc):
|
||||
"""Test that tool_use blocks are not truncated."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "some_function",
|
||||
"input": {"key": "value" * 1000}, # Large input
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
original = json.dumps(msg["content"][0]["input"])
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=10)
|
||||
|
||||
# tool_use should be unchanged
|
||||
assert json.dumps(msg["content"][0]["input"]) == original
|
||||
|
||||
def test_no_truncation_when_under_limit(self, enc):
|
||||
"""Test that short content is not modified."""
|
||||
msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"}
|
||||
|
||||
original = msg["content"]
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=1000)
|
||||
|
||||
assert msg["content"] == original
|
||||
|
||||
|
||||
class TestTruncateMiddleTokens:
|
||||
"""Test middle truncation of text."""
|
||||
|
||||
@pytest.fixture
|
||||
def enc(self):
|
||||
return encoding_for_model("gpt-4o")
|
||||
|
||||
def test_truncates_long_text(self, enc):
|
||||
"""Test that long text is truncated with ellipsis in middle."""
|
||||
long_text = "word " * 1000
|
||||
result = _truncate_middle_tokens(long_text, enc, max_tok=50)
|
||||
|
||||
assert len(enc.encode(result)) <= 52 # Allow some slack for ellipsis
|
||||
assert "…" in result
|
||||
assert result.startswith("word") # Head preserved
|
||||
assert result.endswith("word ") # Tail preserved
|
||||
|
||||
def test_preserves_short_text(self, enc):
|
||||
"""Test that short text is not modified."""
|
||||
short_text = "Hello world"
|
||||
result = _truncate_middle_tokens(short_text, enc, max_tok=100)
|
||||
|
||||
assert result == short_text
|
||||
|
||||
|
||||
class TestEnsureToolPairsIntact:
|
||||
"""Test tool call/response pair preservation for both OpenAI and Anthropic formats."""
|
||||
|
||||
# ---- OpenAI Format Tests ----
|
||||
|
||||
def test_openai_adds_missing_tool_call(self):
|
||||
"""Test that orphaned OpenAI tool_response gets its tool_call prepended."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (the tool response)
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the tool_call message
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert "tool_calls" in result[0]
|
||||
|
||||
def test_openai_keeps_complete_pairs(self):
|
||||
"""Test that complete OpenAI pairs are unchanged."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
|
||||
]
|
||||
recent = all_msgs[1:] # Include both tool_call and response
|
||||
start_index = 1
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
assert len(result) == 2 # No messages added
|
||||
|
||||
def test_openai_multiple_tool_calls(self):
|
||||
"""Test multiple OpenAI tool calls in one assistant message."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}},
|
||||
{"id": "call_2", "type": "function", "function": {"name": "f2"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result1"},
|
||||
{"role": "tool", "tool_call_id": "call_2", "content": "result2"},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (first tool response)
|
||||
recent = [all_msgs[2], all_msgs[3], all_msgs[4]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the assistant message with both tool_calls
|
||||
assert len(result) == 4
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert len(result[0]["tool_calls"]) == 2
|
||||
|
||||
# ---- Anthropic Format Tests ----
|
||||
|
||||
def test_anthropic_adds_missing_tool_use(self):
|
||||
"""Test that orphaned Anthropic tool_result gets its tool_use prepended."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "SF"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_123",
|
||||
"content": "22°C and sunny",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (the tool_result)
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the tool_use message
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert result[0]["content"][0]["type"] == "tool_use"
|
||||
|
||||
def test_anthropic_keeps_complete_pairs(self):
|
||||
"""Test that complete Anthropic pairs are unchanged."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_456",
|
||||
"name": "calculator",
|
||||
"input": {"expr": "2+2"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_456",
|
||||
"content": "4",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
recent = all_msgs[1:] # Include both tool_use and result
|
||||
start_index = 1
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
assert len(result) == 2 # No messages added
|
||||
|
||||
def test_anthropic_multiple_tool_uses(self):
|
||||
"""Test multiple Anthropic tool_use blocks in one message."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me check both..."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_1",
|
||||
"name": "get_weather",
|
||||
"input": {"city": "NYC"},
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_2",
|
||||
"name": "get_weather",
|
||||
"input": {"city": "LA"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_1",
|
||||
"content": "Cold",
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_2",
|
||||
"content": "Warm",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (tool_result)
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the assistant message with both tool_uses
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
tool_use_count = sum(
|
||||
1 for b in result[0]["content"] if b.get("type") == "tool_use"
|
||||
)
|
||||
assert tool_use_count == 2
|
||||
|
||||
# ---- Mixed/Edge Case Tests ----
|
||||
|
||||
def test_anthropic_with_type_message_field(self):
|
||||
"""Test Anthropic format with 'type': 'message' field (smart_decision_maker style)."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_abc",
|
||||
"name": "search",
|
||||
"input": {"q": "test"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message", # Extra field from smart_decision_maker
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_abc",
|
||||
"content": "Found results",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (the tool_result with 'type': 'message')
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the tool_use message
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert result[0]["content"][0]["type"] == "tool_use"
|
||||
|
||||
def test_handles_no_tool_messages(self):
|
||||
"""Test messages without tool calls."""
|
||||
all_msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
recent = all_msgs
|
||||
start_index = 0
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
assert result == all_msgs
|
||||
|
||||
def test_handles_empty_messages(self):
|
||||
"""Test empty message list."""
|
||||
result = _ensure_tool_pairs_intact([], [], 0)
|
||||
assert result == []
|
||||
|
||||
def test_mixed_text_and_tool_content(self):
|
||||
"""Test Anthropic message with mixed text and tool_use content."""
|
||||
all_msgs = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I'll help you with that."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_mixed",
|
||||
"name": "search",
|
||||
"input": {"q": "test"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_mixed",
|
||||
"content": "Found results",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": "Here are the results..."},
|
||||
]
|
||||
# Start from tool_result
|
||||
recent = [all_msgs[1], all_msgs[2]]
|
||||
start_index = 1
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the assistant message with tool_use
|
||||
assert len(result) == 3
|
||||
assert result[0]["content"][0]["type"] == "text"
|
||||
assert result[0]["content"][1]["type"] == "tool_use"
|
||||
|
||||
|
||||
class TestCompressContext:
|
||||
"""Test the async compress_context function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_compression_needed(self):
|
||||
"""Test messages under limit return without compression."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
]
|
||||
|
||||
result = await compress_context(messages, target_tokens=100000)
|
||||
|
||||
assert isinstance(result, CompressResult)
|
||||
assert result.was_compacted is False
|
||||
assert len(result.messages) == 2
|
||||
assert result.error is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncation_without_client(self):
|
||||
"""Test that truncation works without LLM client."""
|
||||
long_content = "x" * 50000
|
||||
messages = [
|
||||
{"role": "system", "content": "System"},
|
||||
{"role": "user", "content": long_content},
|
||||
{"role": "assistant", "content": "Response"},
|
||||
]
|
||||
|
||||
result = await compress_context(
|
||||
messages, target_tokens=1000, client=None, reserve=100
|
||||
)
|
||||
|
||||
assert result.was_compacted is True
|
||||
# Should have truncated without summarization
|
||||
assert result.messages_summarized == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_mocked_llm_client(self):
|
||||
"""Test summarization with mocked LLM client."""
|
||||
# Create many messages to trigger summarization
|
||||
messages = [{"role": "system", "content": "System prompt"}]
|
||||
for i in range(30):
|
||||
messages.append({"role": "user", "content": f"User message {i} " * 100})
|
||||
messages.append(
|
||||
{"role": "assistant", "content": f"Assistant response {i} " * 100}
|
||||
)
|
||||
|
||||
# Mock the AsyncOpenAI client
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Summary of conversation"
|
||||
mock_client.with_options.return_value.chat.completions.create = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
result = await compress_context(
|
||||
messages,
|
||||
target_tokens=5000,
|
||||
client=mock_client,
|
||||
keep_recent=5,
|
||||
reserve=500,
|
||||
)
|
||||
|
||||
assert result.was_compacted is True
|
||||
# Should have attempted summarization
|
||||
assert mock_client.with_options.called or result.messages_summarized > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_tool_pairs(self):
|
||||
"""Test that tool call/response pairs stay together."""
|
||||
messages = [
|
||||
{"role": "system", "content": "System"},
|
||||
{"role": "user", "content": "Do something"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "func"}}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000},
|
||||
{"role": "assistant", "content": "Done!"},
|
||||
]
|
||||
|
||||
result = await compress_context(
|
||||
messages, target_tokens=500, client=None, reserve=50
|
||||
)
|
||||
|
||||
# Check that if tool response exists, its call exists too
|
||||
tool_call_ids = set()
|
||||
tool_response_ids = set()
|
||||
for msg in result.messages:
|
||||
if "tool_calls" in msg:
|
||||
for tc in msg["tool_calls"]:
|
||||
tool_call_ids.add(tc["id"])
|
||||
if msg.get("role") == "tool":
|
||||
tool_response_ids.add(msg.get("tool_call_id"))
|
||||
|
||||
# All tool responses should have their calls
|
||||
assert tool_response_ids <= tool_call_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_error_when_cannot_compress(self):
|
||||
"""Test that error is returned when compression fails."""
|
||||
# Single huge message that can't be compressed enough
|
||||
messages = [
|
||||
{"role": "user", "content": "x" * 100000},
|
||||
]
|
||||
|
||||
result = await compress_context(
|
||||
messages, target_tokens=100, client=None, reserve=50
|
||||
)
|
||||
|
||||
# Should have an error since we can't get below 100 tokens
|
||||
assert result.error is not None
|
||||
assert result.was_compacted is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_messages(self):
|
||||
"""Test that empty messages list returns early without error."""
|
||||
result = await compress_context([], target_tokens=1000)
|
||||
|
||||
assert result.messages == []
|
||||
assert result.token_count == 0
|
||||
assert result.was_compacted is False
|
||||
assert result.error is None
|
||||
|
||||
|
||||
class TestRemoveOrphanToolResponses:
|
||||
"""Test _remove_orphan_tool_responses helper function."""
|
||||
|
||||
def test_removes_openai_orphan(self):
|
||||
"""Test removal of orphan OpenAI tool response."""
|
||||
from backend.util.prompt import _remove_orphan_tool_responses
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
orphan_ids = {"call_orphan"}
|
||||
|
||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_keeps_valid_openai_tool(self):
|
||||
"""Test that valid OpenAI tool responses are kept."""
|
||||
from backend.util.prompt import _remove_orphan_tool_responses
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "call_valid", "content": "result"},
|
||||
]
|
||||
orphan_ids = {"call_other"}
|
||||
|
||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["tool_call_id"] == "call_valid"
|
||||
|
||||
def test_filters_anthropic_mixed_blocks(self):
|
||||
"""Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan."""
|
||||
from backend.util.prompt import _remove_orphan_tool_responses
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_valid",
|
||||
"content": "valid result",
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_orphan",
|
||||
"content": "orphan result",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
orphan_ids = {"toolu_orphan"}
|
||||
|
||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
|
||||
assert len(result) == 1
|
||||
# Should only have the valid tool_result, orphan filtered out
|
||||
assert len(result[0]["content"]) == 1
|
||||
assert result[0]["content"][0]["tool_use_id"] == "toolu_valid"
|
||||
|
||||
def test_removes_anthropic_all_orphan(self):
|
||||
"""Test removal of Anthropic message when all tool_results are orphans."""
|
||||
from backend.util.prompt import _remove_orphan_tool_responses
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_orphan1",
|
||||
"content": "result1",
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_orphan2",
|
||||
"content": "result2",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
orphan_ids = {"toolu_orphan1", "toolu_orphan2"}
|
||||
|
||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
|
||||
# Message should be completely removed since no content left
|
||||
assert len(result) == 0
|
||||
|
||||
def test_preserves_non_tool_messages(self):
|
||||
"""Test that non-tool messages are preserved."""
|
||||
from backend.util.prompt import _remove_orphan_tool_responses
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
orphan_ids = {"some_id"}
|
||||
|
||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
|
||||
assert result == messages
|
||||
|
||||
|
||||
class TestCompressResultDataclass:
|
||||
"""Test CompressResult dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly."""
|
||||
result = CompressResult(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
token_count=10,
|
||||
was_compacted=False,
|
||||
)
|
||||
|
||||
assert result.error is None
|
||||
assert result.original_token_count == 0 # Defaults to 0, not None
|
||||
assert result.messages_summarized == 0
|
||||
assert result.messages_dropped == 0
|
||||
|
||||
def test_all_fields(self):
|
||||
"""Test all fields can be set."""
|
||||
result = CompressResult(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
token_count=100,
|
||||
was_compacted=True,
|
||||
error="Some error",
|
||||
original_token_count=500,
|
||||
messages_summarized=10,
|
||||
messages_dropped=5,
|
||||
)
|
||||
|
||||
assert result.token_count == 100
|
||||
assert result.was_compacted is True
|
||||
assert result.error == "Some error"
|
||||
assert result.original_token_count == 500
|
||||
assert result.messages_summarized == 10
|
||||
assert result.messages_dropped == 5
|
||||
|
||||
@@ -157,12 +157,7 @@ async def validate_url(
|
||||
is_trusted: Boolean indicating if the hostname is in trusted_origins
|
||||
ip_addresses: List of IP addresses for the host; empty if the host is trusted
|
||||
"""
|
||||
# Canonicalize URL
|
||||
url = url.strip("/ ").replace("\\", "/")
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme:
|
||||
url = f"http://{url}"
|
||||
parsed = urlparse(url)
|
||||
parsed = parse_url(url)
|
||||
|
||||
# Check scheme
|
||||
if parsed.scheme not in ALLOWED_SCHEMES:
|
||||
@@ -220,6 +215,17 @@ async def validate_url(
|
||||
)
|
||||
|
||||
|
||||
def parse_url(url: str) -> URL:
|
||||
"""Canonicalizes and parses a URL string."""
|
||||
url = url.strip("/ ").replace("\\", "/")
|
||||
|
||||
# Ensure scheme is present for proper parsing
|
||||
if not re.match(r"[a-z0-9+.\-]+://", url):
|
||||
url = f"http://{url}"
|
||||
|
||||
return urlparse(url)
|
||||
|
||||
|
||||
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
|
||||
"""
|
||||
Pins a URL to a specific IP address to prevent DNS rebinding attacks.
|
||||
|
||||
@@ -656,6 +656,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
e2b_api_key: str = Field(default="", description="E2B API key")
|
||||
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
||||
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
||||
elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
|
||||
|
||||
linear_client_id: str = Field(default="", description="Linear client ID")
|
||||
linear_client_secret: str = Field(default="", description="Linear client secret")
|
||||
|
||||
@@ -22,6 +22,7 @@ from backend.data.workspace import (
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -187,6 +188,9 @@ class WorkspaceManager:
|
||||
f"{Config().max_file_size_mb}MB limit"
|
||||
)
|
||||
|
||||
# Virus scan content before persisting (defense in depth)
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Determine path with session scoping
|
||||
if path is None:
|
||||
path = f"/{filename}"
|
||||
|
||||
7148
autogpt_platform/backend/poetry.lock
generated
7148
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user