Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-12 07:45:14 -05:00)

Compare commits: fix/file-i... -> fix/copilo... (77 commits)
Commits (SHA1):
5348d97437
113e87a23c
d09f1532a4
a78145505b
36aeb0b2b3
2a189c44c4
508759610f
062fe1aa70
2cd0d4fe0f
1ecae8c87e
659338f90c
4df5b7bde7
017a00af46
6573d987ea
52650eed1d
ae8ce8b4ca
81c1524658
f2ead70f3d
7d4c020a9b
e596ea87cb
81f8290f01
6467f6734f
5a30d11416
1f4105e8f9
caf9ff34e6
e8fc8ee623
1a16e203b8
5dae303ce0
6cbfbdd013
0c6fa60436
b04e916c23
1a32ba7d9a
deccc26f1f
9e38bd5b78
a329831b0b
98dd1a9480
9c7c598c7d
728c40def5
cd64562e1b
8fddc9d71f
3d1cd03fc8
e7ebe42306
e0fab7e34e
29ee85c86f
85b6520710
bfa942e032
11256076d8
3ca2387631
ed07f02738
b121030c94
c22c18374d
e40233a3ac
3ae5eabf9d
a077ba9f03
5401d54eaa
5ac89d7c0b
4f908d5cb3
c1aa684743
7e5b84cc5c
09cb313211
c026485023
1eabc60484
f4bf492f24
81e48c00a4
7dc53071e8
4878665c66
f7350c797a
2abbb7fbc8
05b60db554
cc4839bedb
dbbff04616
e6438b9a76
e10ff8d37f
9538992eaf
27b72062f2
9a79a8d257
a9bf08748b
.github/workflows/classic-frontend-ci.yml (vendored): 2 lines changed

@@ -49,7 +49,7 @@ jobs:
 
       - name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
         if: github.event_name == 'push'
-        uses: peter-evans/create-pull-request@v7
+        uses: peter-evans/create-pull-request@v8
         with:
           add-paths: classic/frontend/build/web
           base: ${{ github.ref_name }}

@@ -22,7 +22,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           ref: ${{ github.event.workflow_run.head_branch }}
           fetch-depth: 0

@@ -42,7 +42,7 @@ jobs:
 
       - name: Get CI failure details
         id: failure_details
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
         with:
           script: |
             const run = await github.rest.actions.getWorkflowRun({
.github/workflows/claude-dependabot.yml (vendored): 11 lines changed

@@ -30,7 +30,7 @@ jobs:
       actions: read # Required for CI access
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 1
 

@@ -41,7 +41,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -78,7 +78,7 @@ jobs:
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
           node-version: "22"
 

@@ -91,7 +91,7 @@ jobs:
           echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
 
      - name: Cache frontend dependencies
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.pnpm-store
           key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}

@@ -124,7 +124,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes

@@ -309,6 +309,7 @@ jobs:
         uses: anthropics/claude-code-action@v1
         with:
           claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
+          allowed_bots: "dependabot[bot]"
           claude_args: |
             --allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
           prompt: |
.github/workflows/claude.yml (vendored): 10 lines changed

@@ -40,7 +40,7 @@ jobs:
       actions: read # Required for CI access
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 1
 

@@ -57,7 +57,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -94,7 +94,7 @@ jobs:
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: "22"
 

@@ -107,7 +107,7 @@ jobs:
           echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
 
       - name: Cache frontend dependencies
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.pnpm-store
           key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}

@@ -140,7 +140,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes
.github/workflows/codeql.yml (vendored): 2 lines changed

@@ -58,7 +58,7 @@ jobs:
     # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
 
       # Initializes the CodeQL tools for scanning.
       - name: Initialize CodeQL
.github/workflows/copilot-setup-steps.yml (vendored): 10 lines changed

@@ -27,7 +27,7 @@ jobs:
     # If you do not check out your code, Copilot will do this for you.
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           submodules: true

@@ -39,7 +39,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -76,7 +76,7 @@ jobs:
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: "22"
 

@@ -89,7 +89,7 @@ jobs:
           echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
 
       - name: Cache frontend dependencies
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.pnpm-store
           key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}

@@ -132,7 +132,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes
.github/workflows/docs-block-sync.yml (vendored): 4 lines changed

@@ -23,7 +23,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 1
 

@@ -33,7 +33,7 @@ jobs:
           python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
.github/workflows/docs-claude-review.yml (vendored): 4 lines changed

@@ -23,7 +23,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
 

@@ -33,7 +33,7 @@ jobs:
           python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
.github/workflows/docs-enhance.yml (vendored): 4 lines changed

@@ -28,7 +28,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 1
 

@@ -38,7 +38,7 @@ jobs:
           python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -25,7 +25,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           ref: ${{ github.event.inputs.git_ref || github.ref_name }}
 

@@ -52,7 +52,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Trigger deploy workflow
-        uses: peter-evans/repository-dispatch@v3
+        uses: peter-evans/repository-dispatch@v4
         with:
           token: ${{ secrets.DEPLOY_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -17,7 +17,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           ref: ${{ github.ref_name || 'master' }}
 

@@ -45,7 +45,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Trigger deploy workflow
-        uses: peter-evans/repository-dispatch@v3
+        uses: peter-evans/repository-dispatch@v4
         with:
           token: ${{ secrets.DEPLOY_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
.github/workflows/platform-backend-ci.yml (vendored): 4 lines changed

@@ -68,7 +68,7 @@ jobs:
 
     steps:
      - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           submodules: true

@@ -88,7 +88,7 @@ jobs:
         run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -17,7 +17,7 @@ jobs:
       - name: Check comment permissions and deployment status
        id: check_status
        if: github.event_name == 'issue_comment' && github.event.issue.pull_request
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
        with:
          script: |
            const commentBody = context.payload.comment.body.trim();

@@ -55,7 +55,7 @@ jobs:
 
       - name: Post permission denied comment
         if: steps.check_status.outputs.permission_denied == 'true'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
         with:
           script: |
             await github.rest.issues.createComment({

@@ -68,7 +68,7 @@ jobs:
       - name: Get PR details for deployment
         id: pr_details
         if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
         with:
           script: |
             const pr = await github.rest.pulls.get({

@@ -82,7 +82,7 @@ jobs:
 
       - name: Dispatch Deploy Event
         if: steps.check_status.outputs.should_deploy == 'true'
-        uses: peter-evans/repository-dispatch@v3
+        uses: peter-evans/repository-dispatch@v4
         with:
           token: ${{ secrets.DISPATCH_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

@@ -98,7 +98,7 @@ jobs:
 
       - name: Post deploy success comment
         if: steps.check_status.outputs.should_deploy == 'true'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
         with:
           script: |
             await github.rest.issues.createComment({

@@ -110,7 +110,7 @@ jobs:
 
       - name: Dispatch Undeploy Event (from comment)
         if: steps.check_status.outputs.should_undeploy == 'true'
-        uses: peter-evans/repository-dispatch@v3
+        uses: peter-evans/repository-dispatch@v4
         with:
           token: ${{ secrets.DISPATCH_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

@@ -126,7 +126,7 @@ jobs:
 
       - name: Post undeploy success comment
         if: steps.check_status.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
         with:
           script: |
             await github.rest.issues.createComment({

@@ -139,7 +139,7 @@ jobs:
       - name: Check deployment status on PR close
         id: check_pr_close
         if: github.event_name == 'pull_request' && github.event.action == 'closed'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
         with:
           script: |
             const comments = await github.rest.issues.listComments({

@@ -168,7 +168,7 @@ jobs:
           github.event_name == 'pull_request' &&
           github.event.action == 'closed' &&
           steps.check_pr_close.outputs.should_undeploy == 'true'
-        uses: peter-evans/repository-dispatch@v3
+        uses: peter-evans/repository-dispatch@v4
         with:
           token: ${{ secrets.DISPATCH_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

@@ -187,7 +187,7 @@ jobs:
           github.event_name == 'pull_request' &&
           github.event.action == 'closed' &&
           steps.check_pr_close.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
         with:
           script: |
             await github.rest.issues.createComment({
.github/workflows/platform-frontend-ci.yml (vendored): 48 lines changed

@@ -27,13 +27,22 @@ jobs:
     runs-on: ubuntu-latest
     outputs:
       cache-key: ${{ steps.cache-key.outputs.key }}
+      components-changed: ${{ steps.filter.outputs.components }}
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
 
+      - name: Check for component changes
+        uses: dorny/paths-filter@v3
+        id: filter
+        with:
+          filters: |
+            components:
+              - 'autogpt_platform/frontend/src/components/**'
+
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: "22.18.0"
 

@@ -45,7 +54,7 @@ jobs:
         run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
 
       - name: Cache dependencies
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.pnpm-store
           key: ${{ steps.cache-key.outputs.key }}

@@ -62,10 +71,10 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: "22.18.0"
 

@@ -73,7 +82,7 @@ jobs:
         run: corepack enable
 
       - name: Restore dependencies cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}

@@ -90,17 +99,20 @@ jobs:
   chromatic:
     runs-on: ubuntu-latest
     needs: setup
-    # Only run on dev branch pushes or PRs targeting dev
-    if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
+    # Disabled: to re-enable, remove 'false &&' from the condition below
+    if: >-
+      false
+      && (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
+      && needs.setup.outputs.components-changed == 'true'
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: "22.18.0"
 

@@ -108,7 +120,7 @@ jobs:
         run: corepack enable
 
       - name: Restore dependencies cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}

@@ -136,12 +148,12 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           submodules: recursive
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: "22.18.0"
 

@@ -164,7 +176,7 @@ jobs:
         uses: docker/setup-buildx-action@v3
 
       - name: Cache Docker layers
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: /tmp/.buildx-cache
           key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}

@@ -219,7 +231,7 @@ jobs:
          fi
 
       - name: Restore dependencies cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}

@@ -265,12 +277,12 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           submodules: recursive
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: "22.18.0"
 

@@ -278,7 +290,7 @@ jobs:
         run: corepack enable
 
       - name: Restore dependencies cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}
.github/workflows/platform-fullstack-ci.yml (vendored): 16 lines changed

@@ -29,10 +29,10 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: "22.18.0"
 

@@ -44,7 +44,7 @@ jobs:
         run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
 
       - name: Cache dependencies
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.pnpm-store
           key: ${{ steps.cache-key.outputs.key }}

@@ -56,19 +56,19 @@ jobs:
         run: pnpm install --frozen-lockfile
 
   types:
-    runs-on: ubuntu-latest
+    runs-on: big-boi
     needs: setup
     strategy:
       fail-fast: false
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           submodules: recursive
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
         with:
           node-version: "22.18.0"
 

@@ -85,10 +85,10 @@ jobs:
 
       - name: Run docker compose
         run: |
-          docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
+          docker compose -f ../docker-compose.yml --profile local up -d deps_backend
 
       - name: Restore dependencies cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}
.github/workflows/repo-workflow-checker.yml (vendored): 2 lines changed

@@ -11,7 +11,7 @@ jobs:
     steps:
       # - name: Wait some time for all actions to start
      #   run: sleep 30
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
       # with:
       #   fetch-depth: 0
       - name: Set up Python
.gitignore (vendored): 1 line changed

@@ -180,3 +180,4 @@ autogpt_platform/backend/settings.py
 .claude/settings.local.json
 CLAUDE.local.md
 /autogpt_platform/backend/logs
+.next
autogpt_platform/autogpt_libs/poetry.lock (generated): 1862 lines changed
File diff suppressed because it is too large.
@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
 [tool.poetry.dependencies]
 python = ">=3.10,<4.0"
 colorama = "^0.4.6"
-cryptography = "^45.0"
+cryptography = "^46.0"
 expiringdict = "^1.2.2"
-fastapi = "^0.116.1"
+fastapi = "^0.128.0"
-google-cloud-logging = "^3.12.1"
+google-cloud-logging = "^3.13.0"
-launchdarkly-server-sdk = "^9.12.0"
+launchdarkly-server-sdk = "^9.14.1"
-pydantic = "^2.11.7"
+pydantic = "^2.12.5"
-pydantic-settings = "^2.10.1"
+pydantic-settings = "^2.12.0"
-pyjwt = { version = "^2.10.1", extras = ["crypto"] }
+pyjwt = { version = "^2.11.0", extras = ["crypto"] }
 redis = "^6.2.0"
-supabase = "^2.16.0"
+supabase = "^2.27.2"
-uvicorn = "^0.35.0"
+uvicorn = "^0.40.0"
 
 [tool.poetry.group.dev.dependencies]
-pyright = "^1.1.404"
+pyright = "^1.1.408"
 pytest = "^8.4.1"
-pytest-asyncio = "^1.1.0"
+pytest-asyncio = "^1.3.0"
-pytest-mock = "^3.14.1"
+pytest-mock = "^3.15.1"
-pytest-cov = "^6.2.1"
+pytest-cov = "^7.0.0"
-ruff = "^0.12.11"
+ruff = "^0.15.0"
 
 [build-system]
 requires = ["poetry-core"]
@@ -152,6 +152,7 @@ REPLICATE_API_KEY=
 REVID_API_KEY=
 SCREENSHOTONE_API_KEY=
 UNREAL_SPEECH_API_KEY=
+ELEVENLABS_API_KEY=
 
 # Data & Search Services
 E2B_API_KEY=
autogpt_platform/backend/.gitignore (vendored): 3 lines changed

@@ -19,3 +19,6 @@ load-tests/*.json
 load-tests/*.log
 load-tests/node_modules/*
 migrations/*/rollback*.sql
+
+# Workspace files
+workspaces/
@@ -62,10 +62,12 @@ ENV POETRY_HOME=/opt/poetry \
     DEBIAN_FRONTEND=noninteractive
 ENV PATH=/opt/poetry/bin:$PATH
 
-# Install Python without upgrading system-managed packages
+# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
 RUN apt-get update && apt-get install -y \
     python3.13 \
     python3-pip \
+    ffmpeg \
+    imagemagick \
     && rm -rf /var/lib/apt/lists/*
 
 # Copy only necessary files from builder
@@ -10,7 +10,7 @@ from typing_extensions import TypedDict
 
 import backend.api.features.store.cache as store_cache
 import backend.api.features.store.model as store_model
-import backend.data.block
+import backend.blocks
 from backend.api.external.middleware import require_permission
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db

@@ -67,7 +67,7 @@ async def get_user_info(
     dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
 )
 async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
-    blocks = [block() for block in backend.data.block.get_blocks().values()]
+    blocks = [block() for block in backend.blocks.get_blocks().values()]
     return [b.to_dict() for b in blocks if not b.disabled]
 
 

@@ -83,7 +83,7 @@ async def execute_graph_block(
         require_permission(APIKeyPermission.EXECUTE_BLOCK)
     ),
 ) -> CompletedBlockOutput:
-    obj = backend.data.block.get_block(block_id)
+    obj = backend.blocks.get_block(block_id)
     if not obj:
         raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
     if obj.disabled:
@@ -10,10 +10,15 @@ import backend.api.features.library.db as library_db
 import backend.api.features.library.model as library_model
 import backend.api.features.store.db as store_db
 import backend.api.features.store.model as store_model
-import backend.data.block
 from backend.blocks import load_all_blocks
+from backend.blocks._base import (
+    AnyBlockSchema,
+    BlockCategory,
+    BlockInfo,
+    BlockSchema,
+    BlockType,
+)
 from backend.blocks.llm import LlmModel
-from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
 from backend.data.db import query_raw_with_schema
 from backend.integrations.providers import ProviderName
 from backend.util.cache import cached

@@ -22,7 +27,7 @@ from backend.util.models import Pagination
 from .model import (
     BlockCategoryResponse,
     BlockResponse,
-    BlockType,
+    BlockTypeFilter,
     CountResponse,
     FilterType,
     Provider,

@@ -88,7 +93,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
 def get_blocks(
     *,
     category: str | None = None,
-    type: BlockType | None = None,
+    type: BlockTypeFilter | None = None,
     provider: ProviderName | None = None,
     page: int = 1,
     page_size: int = 50,

@@ -669,9 +674,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
     for block_type in load_all_blocks().values():
         block: AnyBlockSchema = block_type()
         if block.disabled or block.block_type in (
-            backend.data.block.BlockType.INPUT,
+            BlockType.INPUT,
-            backend.data.block.BlockType.OUTPUT,
+            BlockType.OUTPUT,
-            backend.data.block.BlockType.AGENT,
+            BlockType.AGENT,
         ):
             continue
         # Find the execution count for this block
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 
 import backend.api.features.library.model as library_model
 import backend.api.features.store.model as store_model
-from backend.data.block import BlockInfo
+from backend.blocks._base import BlockInfo
 from backend.integrations.providers import ProviderName
 from backend.util.models import Pagination
 

@@ -15,7 +15,7 @@ FilterType = Literal[
     "my_agents",
 ]
 
-BlockType = Literal["all", "input", "action", "output"]
+BlockTypeFilter = Literal["all", "input", "action", "output"]
 
 
 class SearchEntry(BaseModel):
@@ -88,7 +88,7 @@ async def get_block_categories(
 )
 async def get_blocks(
     category: Annotated[str | None, fastapi.Query()] = None,
-    type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
+    type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
     provider: Annotated[ProviderName | None, fastapi.Query()] = None,
     page: Annotated[int, fastapi.Query()] = 1,
     page_size: Annotated[int, fastapi.Query()] = 50,
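Note on the rename above: the query-string filter Literal previously shared the name BlockType with the block-type enum that is now imported from backend.blocks._base, and renaming it to BlockTypeFilter removes that shadowing. A minimal sketch of how the two names can coexist; the enum body, the STANDARD member, and the map_filter helper are illustrative assumptions, not code from this changeset:

# Minimal sketch (not from this changeset): the renamed filter Literal and the
# block-type enum no longer collide. The enum members and map_filter are assumptions.
from enum import Enum
from typing import Literal

BlockTypeFilter = Literal["all", "input", "action", "output"]  # API query filter


class BlockType(Enum):  # stand-in for backend.blocks._base.BlockType
    INPUT = "input"
    OUTPUT = "output"
    AGENT = "agent"
    STANDARD = "standard"  # hypothetical member for ordinary action blocks


def map_filter(type_filter: BlockTypeFilter) -> set[BlockType] | None:
    """Translate an API filter value into the enum members it selects (hypothetical)."""
    if type_filter == "input":
        return {BlockType.INPUT}
    if type_filter == "output":
        return {BlockType.OUTPUT}
    if type_filter == "action":
        return {BlockType.STANDARD, BlockType.AGENT}
    return None  # "all": no filtering


assert map_filter("input") == {BlockType.INPUT}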
@@ -0,0 +1,368 @@
|
|||||||
|
"""Redis Streams consumer for operation completion messages.
|
||||||
|
|
||||||
|
This module provides a consumer (ChatCompletionConsumer) that listens for
|
||||||
|
completion notifications (OperationCompleteMessage) from external services
|
||||||
|
(like Agent Generator) and triggers the appropriate stream registry and
|
||||||
|
chat service updates via process_operation_success/process_operation_failure.
|
||||||
|
|
||||||
|
Why Redis Streams instead of RabbitMQ?
|
||||||
|
--------------------------------------
|
||||||
|
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
||||||
|
queue), Redis Streams was chosen for chat completion notifications because:
|
||||||
|
|
||||||
|
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
||||||
|
Streams (via stream_registry) for message persistence and replay. Using Redis
|
||||||
|
Streams for completion notifications keeps all chat streaming infrastructure
|
||||||
|
in one system, simplifying operations and reducing cross-system coordination.
|
||||||
|
|
||||||
|
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
||||||
|
allowing consumers to replay missed messages after reconnection. This aligns
|
||||||
|
with the SSE reconnection pattern where clients can resume from last_message_id.
|
||||||
|
|
||||||
|
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
||||||
|
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
||||||
|
recovering from dead consumers - ideal for the completion callback pattern.
|
||||||
|
|
||||||
|
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
||||||
|
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
||||||
|
|
||||||
|
5. **Atomicity with Task State**: Completion processing often needs to update
|
||||||
|
task metadata stored in Redis. Keeping both in Redis enables simpler
|
||||||
|
transactional semantics without distributed coordination.
|
||||||
|
|
||||||
|
The consumer uses Redis Streams with consumer groups for reliable message
|
||||||
|
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
||||||
|
stale pending messages from dead consumers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
from prisma import Prisma
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from redis.exceptions import ResponseError
|
||||||
|
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
|
|
||||||
|
from . import stream_registry
|
||||||
|
from .completion_handler import process_operation_failure, process_operation_success
|
||||||
|
from .config import ChatConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
config = ChatConfig()
|
||||||
|
|
||||||
|
|
||||||
|
class OperationCompleteMessage(BaseModel):
|
||||||
|
"""Message format for operation completion notifications."""
|
||||||
|
|
||||||
|
operation_id: str
|
||||||
|
task_id: str
|
||||||
|
success: bool
|
||||||
|
result: dict | str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionConsumer:
|
||||||
|
"""Consumer for chat operation completion messages from Redis Streams.
|
||||||
|
|
||||||
|
This consumer initializes its own Prisma client in start() to ensure
|
||||||
|
database operations work correctly within this async context.
|
||||||
|
|
||||||
|
Uses Redis consumer groups to allow multiple platform pods to consume
|
||||||
|
messages reliably with automatic redelivery on failure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._consumer_task: asyncio.Task | None = None
|
||||||
|
self._running = False
|
||||||
|
self._prisma: Prisma | None = None
|
||||||
|
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start the completion consumer."""
|
||||||
|
if self._running:
|
||||||
|
logger.warning("Completion consumer already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create consumer group if it doesn't exist
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
await redis.xgroup_create(
|
||||||
|
config.stream_completion_name,
|
||||||
|
config.stream_consumer_group,
|
||||||
|
id="0",
|
||||||
|
mkstream=True,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Created consumer group '{config.stream_consumer_group}' "
|
||||||
|
f"on stream '{config.stream_completion_name}'"
|
||||||
|
)
|
||||||
|
except ResponseError as e:
|
||||||
|
if "BUSYGROUP" in str(e):
|
||||||
|
logger.debug(
|
||||||
|
f"Consumer group '{config.stream_consumer_group}' already exists"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
self._running = True
|
||||||
|
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||||
|
logger.info(
|
||||||
|
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _ensure_prisma(self) -> Prisma:
|
||||||
|
"""Lazily initialize Prisma client on first use."""
|
||||||
|
if self._prisma is None:
|
||||||
|
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
||||||
|
self._prisma = Prisma(datasource={"url": database_url})
|
||||||
|
await self._prisma.connect()
|
||||||
|
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
|
||||||
|
return self._prisma
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop the completion consumer."""
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
if self._consumer_task:
|
||||||
|
self._consumer_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._consumer_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._consumer_task = None
|
||||||
|
|
||||||
|
if self._prisma:
|
||||||
|
await self._prisma.disconnect()
|
||||||
|
self._prisma = None
|
||||||
|
logger.info("[COMPLETION] Consumer Prisma client disconnected")
|
||||||
|
|
||||||
|
logger.info("Chat completion consumer stopped")
|
||||||
|
|
||||||
|
async def _consume_messages(self) -> None:
|
||||||
|
"""Main message consumption loop with retry logic."""
|
||||||
|
max_retries = 10
|
||||||
|
retry_delay = 5 # seconds
|
||||||
|
retry_count = 0
|
||||||
|
block_timeout = 5000 # milliseconds
|
||||||
|
|
||||||
|
while self._running and retry_count < max_retries:
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
|
||||||
|
# Reset retry count on successful connection
|
||||||
|
retry_count = 0
|
||||||
|
|
||||||
|
while self._running:
|
||||||
|
# First, claim any stale pending messages from dead consumers
|
||||||
|
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
||||||
|
# claim them using XAUTOCLAIM
|
||||||
|
try:
|
||||||
|
claimed_result = await redis.xautoclaim(
|
||||||
|
name=config.stream_completion_name,
|
||||||
|
groupname=config.stream_consumer_group,
|
||||||
|
consumername=self._consumer_name,
|
||||||
|
min_idle_time=config.stream_claim_min_idle_ms,
|
||||||
|
start_id="0-0",
|
||||||
|
count=10,
|
||||||
|
)
|
||||||
|
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
||||||
|
if claimed_result and len(claimed_result) >= 2:
|
||||||
|
claimed_entries = claimed_result[1]
|
||||||
|
if claimed_entries:
|
||||||
|
logger.info(
|
||||||
|
f"Claimed {len(claimed_entries)} stale pending messages"
|
||||||
|
)
|
||||||
|
for entry_id, data in claimed_entries:
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
await self._process_entry(redis, entry_id, data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
||||||
|
|
||||||
|
# Read new messages from the stream
|
||||||
|
messages = await redis.xreadgroup(
|
||||||
|
groupname=config.stream_consumer_group,
|
||||||
|
consumername=self._consumer_name,
|
||||||
|
streams={config.stream_completion_name: ">"},
|
||||||
|
block=block_timeout,
|
||||||
|
count=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for stream_name, entries in messages:
|
||||||
|
for entry_id, data in entries:
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
await self._process_entry(redis, entry_id, data)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Consumer cancelled")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
retry_count += 1
|
||||||
|
logger.error(
|
||||||
|
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
if self._running and retry_count < max_retries:
|
||||||
|
await asyncio.sleep(retry_delay)
|
||||||
|
else:
|
||||||
|
logger.error("Max retries reached, stopping consumer")
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _process_entry(
|
||||||
|
self, redis: Any, entry_id: str, data: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Process a single stream entry and acknowledge it on success.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis: Redis client connection
|
||||||
|
entry_id: The stream entry ID
|
||||||
|
data: The entry data dict
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Handle the message
|
||||||
|
message_data = data.get("data")
|
||||||
|
if message_data:
|
||||||
|
await self._handle_message(
|
||||||
|
message_data.encode()
|
||||||
|
                    if isinstance(message_data, str)
                    else message_data
                )

                # Acknowledge the message after successful processing
                await redis.xack(
                    config.stream_completion_name,
                    config.stream_consumer_group,
                    entry_id,
                )
            except Exception as e:
                logger.error(
                    f"Error processing completion message {entry_id}: {e}",
                    exc_info=True,
                )
                # Message remains in pending state and will be claimed by
                # XAUTOCLAIM after min_idle_time expires

    async def _handle_message(self, body: bytes) -> None:
        """Handle a completion message using our own Prisma client."""
        try:
            data = orjson.loads(body)
            message = OperationCompleteMessage(**data)
        except Exception as e:
            logger.error(f"Failed to parse completion message: {e}")
            return

        logger.info(
            f"[COMPLETION] Received completion for operation {message.operation_id} "
            f"(task_id={message.task_id}, success={message.success})"
        )

        # Find task in registry
        task = await stream_registry.find_task_by_operation_id(message.operation_id)
        if task is None:
            task = await stream_registry.get_task(message.task_id)

        if task is None:
            logger.warning(
                f"[COMPLETION] Task not found for operation {message.operation_id} "
                f"(task_id={message.task_id})"
            )
            return

        logger.info(
            f"[COMPLETION] Found task: task_id={task.task_id}, "
            f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
        )

        # Guard against empty task fields
        if not task.task_id or not task.session_id or not task.tool_call_id:
            logger.error(
                f"[COMPLETION] Task has empty critical fields! "
                f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
                f"tool_call_id={task.tool_call_id!r}"
            )
            return

        if message.success:
            await self._handle_success(task, message)
        else:
            await self._handle_failure(task, message)

    async def _handle_success(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle successful operation completion."""
        prisma = await self._ensure_prisma()
        await process_operation_success(task, message.result, prisma)

    async def _handle_failure(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle failed operation completion."""
        prisma = await self._ensure_prisma()
        await process_operation_failure(task, message.error, prisma)


# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None


async def start_completion_consumer() -> None:
    """Start the global completion consumer."""
    global _consumer
    if _consumer is None:
        _consumer = ChatCompletionConsumer()
        await _consumer.start()


async def stop_completion_consumer() -> None:
    """Stop the global completion consumer."""
    global _consumer
    if _consumer:
        await _consumer.stop()
        _consumer = None


async def publish_operation_complete(
    operation_id: str,
    task_id: str,
    success: bool,
    result: dict | str | None = None,
    error: str | None = None,
) -> None:
    """Publish an operation completion message to Redis Streams.

    Args:
        operation_id: The operation ID that completed.
        task_id: The task ID associated with the operation.
        success: Whether the operation succeeded.
        result: The result data (for success).
        error: The error message (for failure).
    """
    message = OperationCompleteMessage(
        operation_id=operation_id,
        task_id=task_id,
        success=success,
        result=result,
        error=error,
    )

    redis = await get_redis_async()
    await redis.xadd(
        config.stream_completion_name,
        {"data": message.model_dump_json()},
        maxlen=config.stream_max_length,
    )
    logger.info(f"Published completion for operation {operation_id}")
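For orientation, a minimal usage sketch (not part of the diff): this is how a worker that finishes a long-running operation could hand its outcome back through the completion stream. Only publish_operation_complete and its signature come from the code above; the import path and the surrounding worker function are assumptions made for illustration.

    # Hypothetical worker-side call site; the module path is an assumption.
    from backend.server.v2.chat.completion_consumer import publish_operation_complete

    async def report_agent_generation(operation_id: str, task_id: str, agent_json: dict) -> None:
        try:
            # ... the actual long-running work happens before this point ...
            await publish_operation_complete(
                operation_id=operation_id,
                task_id=task_id,
                success=True,
                result={"agent_json": agent_json},
            )
        except Exception as exc:
            # Failures are reported on the same stream so the chat side can surface them.
            await publish_operation_complete(
                operation_id=operation_id,
                task_id=task_id,
                success=False,
                error=str(exc),
            )

The consumer group above then reads the entry, processes it, and acknowledges it with XACK; entries left pending by a dead consumer are reclaimed via XAUTOCLAIM once stream_claim_min_idle_ms has elapsed.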
@@ -0,0 +1,344 @@
"""Shared completion handling for operation success and failure.

This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""

import logging
from typing import Any

import orjson
from prisma import Prisma

from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse

logger = logging.getLogger(__name__)

# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}

# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
    {
        "api_key",
        "apikey",
        "api_secret",
        "password",
        "secret",
        "credentials",
        "credential",
        "token",
        "access_token",
        "refresh_token",
        "private_key",
        "privatekey",
        "auth",
        "authorization",
    }
)


def _sanitize_agent_json(obj: Any) -> Any:
    """Recursively sanitize agent_json by removing sensitive keys.

    Args:
        obj: The object to sanitize (dict, list, or primitive)

    Returns:
        Sanitized copy with sensitive keys removed/redacted
    """
    if isinstance(obj, dict):
        return {
            k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
            for k, v in obj.items()
        }
    elif isinstance(obj, list):
        return [_sanitize_agent_json(item) for item in obj]
    else:
        return obj


class ToolMessageUpdateError(Exception):
    """Raised when updating a tool message in the database fails."""

    pass


async def _update_tool_message(
    session_id: str,
    tool_call_id: str,
    content: str,
    prisma_client: Prisma | None,
) -> None:
    """Update tool message in database.

    Args:
        session_id: The session ID
        tool_call_id: The tool call ID to update
        content: The new content for the message
        prisma_client: Optional Prisma client. If None, uses chat_service.

    Raises:
        ToolMessageUpdateError: If the database update fails. The caller should
            handle this to avoid marking the task as completed with inconsistent state.
    """
    try:
        if prisma_client:
            # Use provided Prisma client (for consumer with its own connection)
            updated_count = await prisma_client.chatmessage.update_many(
                where={
                    "sessionId": session_id,
                    "toolCallId": tool_call_id,
                },
                data={"content": content},
            )
            # Check if any rows were updated - 0 means message not found
            if updated_count == 0:
                raise ToolMessageUpdateError(
                    f"No message found with tool_call_id={tool_call_id} in session {session_id}"
                )
        else:
            # Use service function (for webhook endpoint)
            await chat_service._update_pending_operation(
                session_id=session_id,
                tool_call_id=tool_call_id,
                result=content,
            )
    except ToolMessageUpdateError:
        raise
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
        raise ToolMessageUpdateError(
            f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
        ) from e


def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
    """Serialize result to JSON string with sensible defaults.

    Args:
        result: The result to serialize. Can be a dict, list, string,
            number, boolean, or None.

    Returns:
        JSON string representation of the result. Returns '{"status": "completed"}'
        only when result is explicitly None.
    """
    if isinstance(result, str):
        return result
    if result is None:
        return '{"status": "completed"}'
    return orjson.dumps(result).decode("utf-8")


async def _save_agent_from_result(
    result: dict[str, Any],
    user_id: str | None,
    tool_name: str,
) -> dict[str, Any]:
    """Save agent to library if result contains agent_json.

    Args:
        result: The result dict that may contain agent_json
        user_id: The user ID to save the agent for
        tool_name: The tool name (create_agent or edit_agent)

    Returns:
        Updated result dict with saved agent details, or original result if no agent_json
    """
    if not user_id:
        logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
        return result

    agent_json = result.get("agent_json")
    if not agent_json:
        logger.warning(
            f"[COMPLETION] {tool_name} completed but no agent_json in result"
        )
        return result

    try:
        from .tools.agent_generator import save_agent_to_library

        is_update = tool_name == "edit_agent"
        created_graph, library_agent = await save_agent_to_library(
            agent_json, user_id, is_update=is_update
        )

        logger.info(
            f"[COMPLETION] Saved agent '{created_graph.name}' to library "
            f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
        )

        # Return a response similar to AgentSavedResponse
        return {
            "type": "agent_saved",
            "message": f"Agent '{created_graph.name}' has been saved to your library!",
            "agent_id": created_graph.id,
            "agent_name": created_graph.name,
            "library_agent_id": library_agent.id,
            "library_agent_link": f"/library/agents/{library_agent.id}",
            "agent_page_link": f"/build?flowID={created_graph.id}",
        }
    except Exception as e:
        logger.error(
            f"[COMPLETION] Failed to save agent to library: {e}",
            exc_info=True,
        )
        # Return error but don't fail the whole operation
        # Sanitize agent_json to remove sensitive keys before returning
        return {
            "type": "error",
            "message": f"Agent was generated but failed to save: {str(e)}",
            "error": str(e),
            "agent_json": _sanitize_agent_json(agent_json),
        }


async def process_operation_success(
    task: stream_registry.ActiveTask,
    result: dict | str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle successful operation completion.

    Publishes the result to the stream registry, updates the database,
    generates LLM continuation, and marks the task as completed.

    Args:
        task: The active task that completed
        result: The result data from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.

    Raises:
        ToolMessageUpdateError: If the database update fails. The task will be
            marked as failed instead of completed to avoid inconsistent state.
    """
    # For agent generation tools, save the agent to library
    if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
        result = await _save_agent_from_result(result, task.user_id, task.tool_name)

    # Serialize result for output (only substitute default when result is exactly None)
    result_output = result if result is not None else {"status": "completed"}
    output_str = (
        result_output
        if isinstance(result_output, str)
        else orjson.dumps(result_output).decode("utf-8")
    )

    # Publish result to stream registry
    await stream_registry.publish_chunk(
        task.task_id,
        StreamToolOutputAvailable(
            toolCallId=task.tool_call_id,
            toolName=task.tool_name,
            output=output_str,
            success=True,
        ),
    )

    # Update pending operation in database
    # If this fails, we must not continue to mark the task as completed
    result_str = serialize_result(result)
    try:
        await _update_tool_message(
            session_id=task.session_id,
            tool_call_id=task.tool_call_id,
            content=result_str,
            prisma_client=prisma_client,
        )
    except ToolMessageUpdateError:
        # DB update failed - mark task as failed to avoid inconsistent state
        logger.error(
            f"[COMPLETION] DB update failed for task {task.task_id}, "
            "marking as failed instead of completed"
        )
        await stream_registry.publish_chunk(
            task.task_id,
            StreamError(errorText="Failed to save operation result to database"),
        )
        await stream_registry.mark_task_completed(task.task_id, status="failed")
        raise

    # Generate LLM continuation with streaming
    try:
        await chat_service._generate_llm_continuation_with_streaming(
            session_id=task.session_id,
            user_id=task.user_id,
            task_id=task.task_id,
        )
    except Exception as e:
        logger.error(
            f"[COMPLETION] Failed to generate LLM continuation: {e}",
            exc_info=True,
        )

    # Mark task as completed and release Redis lock
    await stream_registry.mark_task_completed(task.task_id, status="completed")
    try:
        await chat_service._mark_operation_completed(task.tool_call_id)
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")

    logger.info(
        f"[COMPLETION] Successfully processed completion for task {task.task_id}"
    )


async def process_operation_failure(
    task: stream_registry.ActiveTask,
    error: str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle failed operation completion.

    Publishes the error to the stream registry, updates the database with
    the error response, and marks the task as failed.

    Args:
        task: The active task that failed
        error: The error message from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.
    """
    error_msg = error or "Operation failed"

    # Publish error to stream registry
    await stream_registry.publish_chunk(
        task.task_id,
        StreamError(errorText=error_msg),
    )

    # Update pending operation with error
    # If this fails, we still continue to mark the task as failed
    error_response = ErrorResponse(
        message=error_msg,
        error=error,
    )
    try:
        await _update_tool_message(
            session_id=task.session_id,
            tool_call_id=task.tool_call_id,
            content=error_response.model_dump_json(),
            prisma_client=prisma_client,
        )
    except ToolMessageUpdateError:
        # DB update failed - log but continue with cleanup
        logger.error(
            f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
            "continuing with cleanup"
        )

    # Mark task as failed and release Redis lock
    await stream_registry.mark_task_completed(task.task_id, status="failed")
    try:
        await chat_service._mark_operation_completed(task.tool_call_id)
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")

    logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
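As a hedged illustration of the second entry point named in the module docstring (the HTTP webhook), a handler could route an external completion into the shared functions above. The endpoint path, function name, and auth handling below are assumptions; process_operation_success, process_operation_failure, stream_registry.find_task_by_operation_id, and the OperationCompleteRequest fields all come from this changeset.

    # Sketch only; the real webhook route ships in routes.py. Paths are assumed.
    from fastapi import APIRouter, HTTPException
    from pydantic import BaseModel

    from . import stream_registry
    from .completion_handler import process_operation_failure, process_operation_success

    router = APIRouter()

    class OperationCompleteRequest(BaseModel):  # mirrors the model added in routes.py
        success: bool
        result: dict | str | None = None
        error: str | None = None

    @router.post("/operations/{operation_id}/complete")  # hypothetical path
    async def operation_complete_webhook(operation_id: str, body: OperationCompleteRequest):
        # Map the external operation back to the chat task that is waiting on it.
        task = await stream_registry.find_task_by_operation_id(operation_id)
        if task is None:
            raise HTTPException(status_code=404, detail="Unknown operation")
        if body.success:
            await process_operation_success(task, body.result)
        else:
            await process_operation_failure(task, body.error)
        return {"status": "ok"}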
@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):

    # OpenAI API Configuration
    model: str = Field(
-        default="anthropic/claude-opus-4.5", description="Default model to use"
+        default="anthropic/claude-opus-4.6", description="Default model to use"
    )
    title_model: str = Field(
        default="openai/gpt-4o-mini",
@@ -44,6 +44,48 @@ class ChatConfig(BaseSettings):
        description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
    )

+    # Stream registry configuration for SSE reconnection
+    stream_ttl: int = Field(
+        default=3600,
+        description="TTL in seconds for stream data in Redis (1 hour)",
+    )
+    stream_max_length: int = Field(
+        default=10000,
+        description="Maximum number of messages to store per stream",
+    )
+
+    # Redis Streams configuration for completion consumer
+    stream_completion_name: str = Field(
+        default="chat:completions",
+        description="Redis Stream name for operation completions",
+    )
+    stream_consumer_group: str = Field(
+        default="chat_consumers",
+        description="Consumer group name for completion stream",
+    )
+    stream_claim_min_idle_ms: int = Field(
+        default=60000,
+        description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
+    )
+
+    # Redis key prefixes for stream registry
+    task_meta_prefix: str = Field(
+        default="chat:task:meta:",
+        description="Prefix for task metadata hash keys",
+    )
+    task_stream_prefix: str = Field(
+        default="chat:stream:",
+        description="Prefix for task message stream keys",
+    )
+    task_op_prefix: str = Field(
+        default="chat:task:op:",
+        description="Prefix for operation ID to task ID mapping keys",
+    )
+    internal_api_key: str | None = Field(
+        default=None,
+        description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
+    )
+
    # Langfuse Prompt Management Configuration
    # Note: Langfuse credentials are in Settings().secrets (settings.py)
    langfuse_prompt_name: str = Field(
@@ -51,6 +93,12 @@ class ChatConfig(BaseSettings):
        description="Name of the prompt in Langfuse to fetch",
    )

+    # Extended thinking configuration for Claude models
+    thinking_enabled: bool = Field(
+        default=True,
+        description="Enable adaptive thinking for Claude models via OpenRouter",
+    )
+
    @field_validator("api_key", mode="before")
    @classmethod
    def get_api_key(cls, v):
@@ -82,6 +130,14 @@ class ChatConfig(BaseSettings):
            v = "https://openrouter.ai/api/v1"
        return v

+    @field_validator("internal_api_key", mode="before")
+    @classmethod
+    def get_internal_api_key(cls, v):
+        """Get internal API key from environment if not provided."""
+        if v is None:
+            v = os.getenv("CHAT_INTERNAL_API_KEY")
+        return v
+
    # Prompt paths for different contexts
    PROMPT_PATHS: dict[str, str] = {
        "default": "prompts/chat_system.md",
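A small sketch of the new configuration knobs in use (the values below are illustrative, not defaults from the diff): ChatConfig is a pydantic BaseSettings, so fields can be overridden per instance, and internal_api_key falls back to the CHAT_INTERNAL_API_KEY environment variable through the validator added above.

    # Assumes ChatConfig is importable from this package's config module.
    from .config import ChatConfig

    config = ChatConfig(
        stream_ttl=7200,                 # keep replayable stream data for 2 hours
        stream_claim_min_idle_ms=30000,  # reclaim stuck completion messages sooner
    )
    # Key layout used by the stream registry, e.g. "chat:stream:<task_id>"
    stream_key = config.task_stream_prefix + "some-task-id"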
@@ -45,10 +45,7 @@ async def create_chat_session(
        successfulAgentRuns=SafeJson({}),
        successfulAgentSchedules=SafeJson({}),
    )
-    return await PrismaChatSession.prisma().create(
-        data=data,
-        include={"Messages": True},
-    )
+    return await PrismaChatSession.prisma().create(data=data)


async def update_chat_session(
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
@@ -104,6 +104,26 @@ class ChatSession(BaseModel):
|
|||||||
successful_agent_runs: dict[str, int] = {}
|
successful_agent_runs: dict[str, int] = {}
|
||||||
successful_agent_schedules: dict[str, int] = {}
|
successful_agent_schedules: dict[str, int] = {}
|
||||||
|
|
||||||
|
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
|
||||||
|
"""Attach a tool_call to the current turn's assistant message.
|
||||||
|
|
||||||
|
Searches backwards for the most recent assistant message (stopping at
|
||||||
|
any user message boundary). If found, appends the tool_call to it.
|
||||||
|
Otherwise creates a new assistant message with the tool_call.
|
||||||
|
"""
|
||||||
|
for msg in reversed(self.messages):
|
||||||
|
if msg.role == "user":
|
||||||
|
break
|
||||||
|
if msg.role == "assistant":
|
||||||
|
if not msg.tool_calls:
|
||||||
|
msg.tool_calls = []
|
||||||
|
msg.tool_calls.append(tool_call)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.messages.append(
|
||||||
|
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def new(user_id: str) -> "ChatSession":
|
def new(user_id: str) -> "ChatSession":
|
||||||
return ChatSession(
|
return ChatSession(
|
||||||
@@ -172,6 +192,47 @@ class ChatSession(BaseModel):
|
|||||||
successful_agent_schedules=successful_agent_schedules,
|
successful_agent_schedules=successful_agent_schedules,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_consecutive_assistant_messages(
|
||||||
|
messages: list[ChatCompletionMessageParam],
|
||||||
|
) -> list[ChatCompletionMessageParam]:
|
||||||
|
"""Merge consecutive assistant messages into single messages.
|
||||||
|
|
||||||
|
Long-running tool flows can create split assistant messages: one with
|
||||||
|
text content and another with tool_calls. Anthropic's API requires
|
||||||
|
tool_result blocks to reference a tool_use in the immediately preceding
|
||||||
|
assistant message, so these splits cause 400 errors via OpenRouter.
|
||||||
|
"""
|
||||||
|
if len(messages) < 2:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
result: list[ChatCompletionMessageParam] = [messages[0]]
|
||||||
|
for msg in messages[1:]:
|
||||||
|
prev = result[-1]
|
||||||
|
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
|
||||||
|
result.append(msg)
|
||||||
|
continue
|
||||||
|
|
||||||
|
prev = cast(ChatCompletionAssistantMessageParam, prev)
|
||||||
|
curr = cast(ChatCompletionAssistantMessageParam, msg)
|
||||||
|
|
||||||
|
curr_content = curr.get("content") or ""
|
||||||
|
if curr_content:
|
||||||
|
prev_content = prev.get("content") or ""
|
||||||
|
prev["content"] = (
|
||||||
|
f"{prev_content}\n{curr_content}" if prev_content else curr_content
|
||||||
|
)
|
||||||
|
|
||||||
|
curr_tool_calls = curr.get("tool_calls")
|
||||||
|
if curr_tool_calls:
|
||||||
|
prev_tool_calls = prev.get("tool_calls")
|
||||||
|
prev["tool_calls"] = (
|
||||||
|
list(prev_tool_calls) + list(curr_tool_calls)
|
||||||
|
if prev_tool_calls
|
||||||
|
else list(curr_tool_calls)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||||
messages = []
|
messages = []
|
||||||
for message in self.messages:
|
for message in self.messages:
|
||||||
@@ -258,7 +319,7 @@ class ChatSession(BaseModel):
|
|||||||
name=message.name or "",
|
name=message.name or "",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return messages
|
return self._merge_consecutive_assistant_messages(messages)
|
||||||
|
|
||||||
|
|
||||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||||
|
|||||||
@@ -1,4 +1,16 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletionAssistantMessageParam,
|
||||||
|
ChatCompletionMessageParam,
|
||||||
|
ChatCompletionToolMessageParam,
|
||||||
|
ChatCompletionUserMessageParam,
|
||||||
|
)
|
||||||
|
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||||
|
ChatCompletionMessageToolCallParam,
|
||||||
|
Function,
|
||||||
|
)
|
||||||
|
|
||||||
from .model import (
|
from .model import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
@@ -117,3 +129,205 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
|||||||
loaded.tool_calls is not None
|
loaded.tool_calls is not None
|
||||||
), f"Tool calls missing for {orig.role} message"
|
), f"Tool calls missing for {orig.role} message"
|
||||||
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# _merge_consecutive_assistant_messages #
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
_tc = ChatCompletionMessageToolCallParam(
|
||||||
|
id="tc1", type="function", function=Function(name="do_stuff", arguments="{}")
|
||||||
|
)
|
||||||
|
_tc2 = ChatCompletionMessageToolCallParam(
|
||||||
|
id="tc2", type="function", function=Function(name="other", arguments="{}")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_noop_when_no_consecutive_assistants():
|
||||||
|
"""Messages without consecutive assistants are returned unchanged."""
|
||||||
|
msgs = [
|
||||||
|
ChatCompletionUserMessageParam(role="user", content="hi"),
|
||||||
|
ChatCompletionAssistantMessageParam(role="assistant", content="hello"),
|
||||||
|
ChatCompletionUserMessageParam(role="user", content="bye"),
|
||||||
|
]
|
||||||
|
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
|
||||||
|
assert len(merged) == 3
|
||||||
|
assert [m["role"] for m in merged] == ["user", "assistant", "user"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_splits_text_and_tool_calls():
|
||||||
|
"""The exact bug scenario: text-only assistant followed by tool_calls-only assistant."""
|
||||||
|
msgs = [
|
||||||
|
ChatCompletionUserMessageParam(role="user", content="build agent"),
|
||||||
|
ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant", content="Let me build that"
|
||||||
|
),
|
||||||
|
ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant", content="", tool_calls=[_tc]
|
||||||
|
),
|
||||||
|
ChatCompletionToolMessageParam(role="tool", content="ok", tool_call_id="tc1"),
|
||||||
|
]
|
||||||
|
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
|
||||||
|
|
||||||
|
assert len(merged) == 3
|
||||||
|
assert merged[0]["role"] == "user"
|
||||||
|
assert merged[2]["role"] == "tool"
|
||||||
|
a = cast(ChatCompletionAssistantMessageParam, merged[1])
|
||||||
|
assert a["role"] == "assistant"
|
||||||
|
assert a.get("content") == "Let me build that"
|
||||||
|
assert a.get("tool_calls") == [_tc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_combines_tool_calls_from_both():
|
||||||
|
"""Both consecutive assistants have tool_calls — they get merged."""
|
||||||
|
msgs: list[ChatCompletionAssistantMessageParam] = [
|
||||||
|
ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant", content="text", tool_calls=[_tc]
|
||||||
|
),
|
||||||
|
ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant", content="", tool_calls=[_tc2]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert len(merged) == 1
|
||||||
|
a = cast(ChatCompletionAssistantMessageParam, merged[0])
|
||||||
|
assert a.get("tool_calls") == [_tc, _tc2]
|
||||||
|
assert a.get("content") == "text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_three_consecutive_assistants():
|
||||||
|
"""Three consecutive assistants collapse into one."""
|
||||||
|
msgs: list[ChatCompletionAssistantMessageParam] = [
|
||||||
|
ChatCompletionAssistantMessageParam(role="assistant", content="a"),
|
||||||
|
ChatCompletionAssistantMessageParam(role="assistant", content="b"),
|
||||||
|
ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant", content="", tool_calls=[_tc]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert len(merged) == 1
|
||||||
|
a = cast(ChatCompletionAssistantMessageParam, merged[0])
|
||||||
|
assert a.get("content") == "a\nb"
|
||||||
|
assert a.get("tool_calls") == [_tc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_empty_and_single_message():
|
||||||
|
"""Edge cases: empty list and single message."""
|
||||||
|
assert ChatSession._merge_consecutive_assistant_messages([]) == []
|
||||||
|
|
||||||
|
single: list[ChatCompletionMessageParam] = [
|
||||||
|
ChatCompletionUserMessageParam(role="user", content="hi")
|
||||||
|
]
|
||||||
|
assert ChatSession._merge_consecutive_assistant_messages(single) == single
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# add_tool_call_to_current_turn #
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
_raw_tc = {
|
||||||
|
"id": "tc1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "f", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
_raw_tc2 = {
|
||||||
|
"id": "tc2",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "g", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_tool_call_appends_to_existing_assistant():
|
||||||
|
"""When the last assistant is from the current turn, tool_call is added to it."""
|
||||||
|
session = ChatSession.new(user_id="u")
|
||||||
|
session.messages = [
|
||||||
|
ChatMessage(role="user", content="hi"),
|
||||||
|
ChatMessage(role="assistant", content="working on it"),
|
||||||
|
]
|
||||||
|
session.add_tool_call_to_current_turn(_raw_tc)
|
||||||
|
|
||||||
|
assert len(session.messages) == 2 # no new message created
|
||||||
|
assert session.messages[1].tool_calls == [_raw_tc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_tool_call_creates_assistant_when_none_exists():
|
||||||
|
"""When there's no current-turn assistant, a new one is created."""
|
||||||
|
session = ChatSession.new(user_id="u")
|
||||||
|
session.messages = [
|
||||||
|
ChatMessage(role="user", content="hi"),
|
||||||
|
]
|
||||||
|
session.add_tool_call_to_current_turn(_raw_tc)
|
||||||
|
|
||||||
|
assert len(session.messages) == 2
|
||||||
|
assert session.messages[1].role == "assistant"
|
||||||
|
assert session.messages[1].tool_calls == [_raw_tc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_tool_call_does_not_cross_user_boundary():
|
||||||
|
"""A user message acts as a boundary — previous assistant is not modified."""
|
||||||
|
session = ChatSession.new(user_id="u")
|
||||||
|
session.messages = [
|
||||||
|
ChatMessage(role="assistant", content="old turn"),
|
||||||
|
ChatMessage(role="user", content="new message"),
|
||||||
|
]
|
||||||
|
session.add_tool_call_to_current_turn(_raw_tc)
|
||||||
|
|
||||||
|
assert len(session.messages) == 3 # new assistant was created
|
||||||
|
assert session.messages[0].tool_calls is None # old assistant untouched
|
||||||
|
assert session.messages[2].role == "assistant"
|
||||||
|
assert session.messages[2].tool_calls == [_raw_tc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_tool_call_multiple_times():
|
||||||
|
"""Multiple long-running tool calls accumulate on the same assistant."""
|
||||||
|
session = ChatSession.new(user_id="u")
|
||||||
|
session.messages = [
|
||||||
|
ChatMessage(role="user", content="hi"),
|
||||||
|
ChatMessage(role="assistant", content="doing stuff"),
|
||||||
|
]
|
||||||
|
session.add_tool_call_to_current_turn(_raw_tc)
|
||||||
|
# Simulate a pending tool result in between (like _yield_tool_call does)
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(role="tool", content="pending", tool_call_id="tc1")
|
||||||
|
)
|
||||||
|
session.add_tool_call_to_current_turn(_raw_tc2)
|
||||||
|
|
||||||
|
assert len(session.messages) == 3 # user, assistant, tool — no extra assistant
|
||||||
|
assert session.messages[1].tool_calls == [_raw_tc, _raw_tc2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_openai_messages_merges_split_assistants():
|
||||||
|
"""End-to-end: session with split assistants produces valid OpenAI messages."""
|
||||||
|
session = ChatSession.new(user_id="u")
|
||||||
|
session.messages = [
|
||||||
|
ChatMessage(role="user", content="build agent"),
|
||||||
|
ChatMessage(role="assistant", content="Let me build that"),
|
||||||
|
ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": "tc1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "create_agent", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
ChatMessage(role="tool", content="done", tool_call_id="tc1"),
|
||||||
|
ChatMessage(role="assistant", content="Saved!"),
|
||||||
|
ChatMessage(role="user", content="show me an example run"),
|
||||||
|
]
|
||||||
|
openai_msgs = session.to_openai_messages()
|
||||||
|
|
||||||
|
# The two consecutive assistants at index 1,2 should be merged
|
||||||
|
roles = [m["role"] for m in openai_msgs]
|
||||||
|
assert roles == ["user", "assistant", "tool", "assistant", "user"]
|
||||||
|
|
||||||
|
# The merged assistant should have both content and tool_calls
|
||||||
|
merged = cast(ChatCompletionAssistantMessageParam, openai_msgs[1])
|
||||||
|
assert merged.get("content") == "Let me build that"
|
||||||
|
tc_list = merged.get("tool_calls")
|
||||||
|
assert tc_list is not None and len(list(tc_list)) == 1
|
||||||
|
assert list(tc_list)[0]["id"] == "tc1"
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.util.json import dumps as json_dumps
|
||||||
|
|
||||||
|
|
||||||
class ResponseType(str, Enum):
|
class ResponseType(str, Enum):
|
||||||
"""Types of streaming responses following AI SDK protocol."""
|
"""Types of streaming responses following AI SDK protocol."""
|
||||||
@@ -18,6 +20,10 @@ class ResponseType(str, Enum):
|
|||||||
START = "start"
|
START = "start"
|
||||||
FINISH = "finish"
|
FINISH = "finish"
|
||||||
|
|
||||||
|
# Step lifecycle (one LLM API call within a message)
|
||||||
|
START_STEP = "start-step"
|
||||||
|
FINISH_STEP = "finish-step"
|
||||||
|
|
||||||
# Text streaming
|
# Text streaming
|
||||||
TEXT_START = "text-start"
|
TEXT_START = "text-start"
|
||||||
TEXT_DELTA = "text-delta"
|
TEXT_DELTA = "text-delta"
|
||||||
@@ -52,6 +58,20 @@ class StreamStart(StreamBaseResponse):
|
|||||||
|
|
||||||
type: ResponseType = ResponseType.START
|
type: ResponseType = ResponseType.START
|
||||||
messageId: str = Field(..., description="Unique message ID")
|
messageId: str = Field(..., description="Unique message ID")
|
||||||
|
taskId: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE format, excluding non-protocol fields like taskId."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"messageId": self.messageId,
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
class StreamFinish(StreamBaseResponse):
|
class StreamFinish(StreamBaseResponse):
|
||||||
@@ -60,6 +80,26 @@ class StreamFinish(StreamBaseResponse):
|
|||||||
type: ResponseType = ResponseType.FINISH
|
type: ResponseType = ResponseType.FINISH
|
||||||
|
|
||||||
|
|
||||||
|
class StreamStartStep(StreamBaseResponse):
|
||||||
|
"""Start of a step (one LLM API call within a message).
|
||||||
|
|
||||||
|
The AI SDK uses this to add a step-start boundary to message.parts,
|
||||||
|
enabling visual separation between multiple LLM calls in a single message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.START_STEP
|
||||||
|
|
||||||
|
|
||||||
|
class StreamFinishStep(StreamBaseResponse):
|
||||||
|
"""End of a step (one LLM API call within a message).
|
||||||
|
|
||||||
|
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
|
||||||
|
so the next LLM call in a tool-call continuation starts with clean state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.FINISH_STEP
|
||||||
|
|
||||||
|
|
||||||
# ========== Text Streaming ==========
|
# ========== Text Streaming ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -113,7 +153,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
# Keep these for internal backend use
|
||||||
toolName: str | None = Field(
|
toolName: str | None = Field(
|
||||||
default=None, description="Name of the tool that was executed"
|
default=None, description="Name of the tool that was executed"
|
||||||
)
|
)
|
||||||
@@ -121,6 +161,17 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
default=True, description="Whether the tool execution succeeded"
|
default=True, description="Whether the tool execution succeeded"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE format, excluding non-spec fields."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"toolCallId": self.toolCallId,
|
||||||
|
"output": self.output,
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
# ========== Other ==========
|
# ========== Other ==========
|
||||||
|
|
||||||
@@ -144,6 +195,18 @@ class StreamError(StreamBaseResponse):
|
|||||||
default=None, description="Additional error details"
|
default=None, description="Additional error details"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE format, only emitting fields required by AI SDK protocol.
|
||||||
|
|
||||||
|
The AI SDK uses z.strictObject({type, errorText}) which rejects
|
||||||
|
any extra fields like `code` or `details`.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"errorText": self.errorText,
|
||||||
|
}
|
||||||
|
return f"data: {json_dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
class StreamHeartbeat(StreamBaseResponse):
|
class StreamHeartbeat(StreamBaseResponse):
|
||||||
"""Heartbeat to keep SSE connection alive during long-running operations.
|
"""Heartbeat to keep SSE connection alive during long-running operations.
|
||||||
|
|||||||
@@ -1,19 +1,45 @@
|
|||||||
"""Chat API routes for chat session management and streaming via SSE."""
|
"""Chat API routes for chat session management and streaming via SSE."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import uuid as uuid_module
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from autogpt_libs import auth
|
from autogpt_libs import auth
|
||||||
from fastapi import APIRouter, Depends, Query, Security
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
from . import service as chat_service
|
from . import service as chat_service
|
||||||
|
from . import stream_registry
|
||||||
|
from .completion_handler import process_operation_failure, process_operation_success
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||||
|
from .response_model import StreamFinish, StreamHeartbeat
|
||||||
|
from .tools.models import (
|
||||||
|
AgentDetailsResponse,
|
||||||
|
AgentOutputResponse,
|
||||||
|
AgentPreviewResponse,
|
||||||
|
AgentSavedResponse,
|
||||||
|
AgentsFoundResponse,
|
||||||
|
BlockListResponse,
|
||||||
|
BlockOutputResponse,
|
||||||
|
ClarificationNeededResponse,
|
||||||
|
DocPageResponse,
|
||||||
|
DocSearchResultsResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
ExecutionStartedResponse,
|
||||||
|
InputValidationErrorResponse,
|
||||||
|
NeedLoginResponse,
|
||||||
|
NoResultsResponse,
|
||||||
|
OperationInProgressResponse,
|
||||||
|
OperationPendingResponse,
|
||||||
|
OperationStartedResponse,
|
||||||
|
SetupRequirementsResponse,
|
||||||
|
UnderstandingUpdatedResponse,
|
||||||
|
)
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
@@ -55,6 +81,15 @@ class CreateSessionResponse(BaseModel):
|
|||||||
user_id: str | None
|
user_id: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveStreamInfo(BaseModel):
|
||||||
|
"""Information about an active stream for reconnection."""
|
||||||
|
|
||||||
|
task_id: str
|
||||||
|
last_message_id: str # Redis Stream message ID for resumption
|
||||||
|
operation_id: str # Operation ID for completion tracking
|
||||||
|
tool_name: str # Name of the tool being executed
|
||||||
|
|
||||||
|
|
||||||
class SessionDetailResponse(BaseModel):
|
class SessionDetailResponse(BaseModel):
|
||||||
"""Response model providing complete details for a chat session, including messages."""
|
"""Response model providing complete details for a chat session, including messages."""
|
||||||
|
|
||||||
@@ -63,6 +98,7 @@ class SessionDetailResponse(BaseModel):
|
|||||||
updated_at: str
|
updated_at: str
|
||||||
user_id: str | None
|
user_id: str | None
|
||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
|
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||||
|
|
||||||
|
|
||||||
class SessionSummaryResponse(BaseModel):
|
class SessionSummaryResponse(BaseModel):
|
||||||
@@ -81,6 +117,14 @@ class ListSessionsResponse(BaseModel):
|
|||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class OperationCompleteRequest(BaseModel):
|
||||||
|
"""Request model for external completion webhook."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
result: dict | str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
# ========== Routes ==========
|
# ========== Routes ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -166,13 +210,14 @@ async def get_session(
|
|||||||
Retrieve the details of a specific chat session.
|
Retrieve the details of a specific chat session.
|
||||||
|
|
||||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||||
|
If there's an active stream for this session, returns the task_id for reconnection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The unique identifier for the desired chat session.
|
session_id: The unique identifier for the desired chat session.
|
||||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
@@ -180,10 +225,27 @@ async def get_session(
|
|||||||
raise NotFoundError(f"Session {session_id} not found.")
|
raise NotFoundError(f"Session {session_id} not found.")
|
||||||
|
|
||||||
messages = [message.model_dump() for message in session.messages]
|
messages = [message.model_dump() for message in session.messages]
|
||||||
logger.info(
|
|
||||||
f"Returning session {session_id}: "
|
# Check if there's an active stream for this session
|
||||||
f"message_count={len(messages)}, "
|
active_stream_info = None
|
||||||
f"roles={[m.get('role') for m in messages]}"
|
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||||
|
session_id, user_id
|
||||||
|
)
|
||||||
|
if active_task:
|
||||||
|
# Filter out the in-progress assistant message from the session response.
|
||||||
|
# The client will receive the complete assistant response through the SSE
|
||||||
|
# stream replay instead, preventing duplicate content.
|
||||||
|
if messages and messages[-1].get("role") == "assistant":
|
||||||
|
messages = messages[:-1]
|
||||||
|
|
||||||
|
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||||
|
# Since we filtered out the cached assistant message, the client needs
|
||||||
|
# the full stream to reconstruct the response.
|
||||||
|
active_stream_info = ActiveStreamInfo(
|
||||||
|
task_id=active_task.task_id,
|
||||||
|
last_message_id="0-0",
|
||||||
|
operation_id=active_task.operation_id,
|
||||||
|
tool_name=active_task.tool_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return SessionDetailResponse(
|
return SessionDetailResponse(
|
||||||
@@ -192,6 +254,7 @@ async def get_session(
|
|||||||
updated_at=session.updated_at.isoformat(),
|
updated_at=session.updated_at.isoformat(),
|
||||||
user_id=session.user_id or None,
|
user_id=session.user_id or None,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
active_stream=active_stream_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -211,19 +274,80 @@ async def stream_chat_post(
|
|||||||
- Tool call UI elements (if invoked)
|
- Tool call UI elements (if invoked)
|
||||||
- Tool execution results
|
- Tool execution results
|
||||||
|
|
||||||
|
The AI generation runs in a background task that continues even if the client disconnects.
|
||||||
|
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||||
|
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The chat session identifier to associate with the streamed messages.
|
session_id: The chat session identifier to associate with the streamed messages.
|
||||||
request: Request body containing message, is_user_message, and optional context.
|
request: Request body containing message, is_user_message, and optional context.
|
||||||
user_id: Optional authenticated user ID.
|
user_id: Optional authenticated user ID.
|
||||||
Returns:
|
Returns:
|
||||||
StreamingResponse: SSE-formatted response chunks.
|
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||||
|
containing the task_id for reconnection.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
stream_start_time = time.perf_counter()
|
||||||
|
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||||
|
f"user={user_id}, message_len={len(request.message)}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
|
||||||
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a task in the stream registry for reconnection support
|
||||||
|
task_id = str(uuid_module.uuid4())
|
||||||
|
operation_id = str(uuid_module.uuid4())
|
||||||
|
log_meta["task_id"] = task_id
|
||||||
|
|
||||||
|
task_create_start = time.perf_counter()
|
||||||
|
await stream_registry.create_task(
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||||
|
tool_name="chat",
|
||||||
|
operation_id=operation_id,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Background task that runs the AI generation independently of SSE connection
|
||||||
|
async def run_ai_generation():
|
||||||
|
import time as time_module
|
||||||
|
|
||||||
|
gen_start_time = time_module.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
first_chunk_time, ttfc = None, None
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
first_chunk_type: str | None = None
|
try:
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
session_id,
|
session_id,
|
||||||
request.message,
|
request.message,
|
||||||
@@ -231,28 +355,182 @@ async def stream_chat_post(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
context=request.context,
|
context=request.context,
|
||||||
|
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
|
||||||
):
|
):
|
||||||
if chunk_count < 3:
|
|
||||||
logger.info(
|
|
||||||
"Chat stream chunk",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_type": str(chunk.type),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if not first_chunk_type:
|
|
||||||
first_chunk_type = str(chunk.type)
|
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
yield chunk.to_sse()
|
if first_chunk_time is None:
|
||||||
|
first_chunk_time = time_module.perf_counter()
|
||||||
|
ttfc = first_chunk_time - gen_start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
"Chat stream completed",
|
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
||||||
extra={
|
extra={
|
||||||
"session_id": session_id,
|
"json_fields": {
|
||||||
"chunk_count": chunk_count,
|
**log_meta,
|
||||||
"first_chunk_type": first_chunk_type,
|
"chunk_type": type(chunk).__name__,
|
||||||
|
"time_to_first_chunk_ms": ttfc * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Write to Redis (subscribers will receive via XREAD)
|
||||||
|
await stream_registry.publish_chunk(task_id, chunk)
|
||||||
|
|
||||||
|
gen_end_time = time_module.perf_counter()
|
||||||
|
total_time = (gen_end_time - gen_start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
||||||
|
f"task={task_id}, session={session_id}, "
|
||||||
|
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"time_to_first_chunk_ms": (
|
||||||
|
ttfc * 1000 if ttfc is not None else None
|
||||||
|
),
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await stream_registry.mark_task_completed(task_id, "completed")
|
||||||
|
except Exception as e:
|
||||||
|
elapsed = time_module.perf_counter() - gen_start_time
|
||||||
|
logger.error(
|
||||||
|
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed * 1000,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await stream_registry.mark_task_completed(task_id, "failed")
|
||||||
|
|
||||||
|
# Start the AI generation in a background task
|
||||||
|
bg_task = asyncio.create_task(run_ai_generation())
|
||||||
|
    await stream_registry.set_task_asyncio_task(task_id, bg_task)

    setup_time = (time.perf_counter() - stream_start_time) * 1000

    logger.info(
        f"[TIMING] Background task started, setup={setup_time:.1f}ms",
        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
    )

    # SSE endpoint that subscribes to the task's stream
    async def event_generator() -> AsyncGenerator[str, None]:
        import time as time_module

        event_gen_start = time_module.perf_counter()
        logger.info(
            f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
            f"user={user_id}",
            extra={"json_fields": log_meta},
        )
        subscriber_queue = None
        first_chunk_yielded = False
        chunks_yielded = 0
        try:
            # Subscribe to the task stream (this replays existing messages + live updates)
            subscriber_queue = await stream_registry.subscribe_to_task(
                task_id=task_id,
                user_id=user_id,
                last_message_id="0-0",  # Get all messages from the beginning
            )

            if subscriber_queue is None:
                yield StreamFinish().to_sse()
                yield "data: [DONE]\n\n"
                return

            # Read from the subscriber queue and yield to SSE
            logger.info(
                "[TIMING] Starting to read from subscriber_queue",
                extra={"json_fields": log_meta},
            )
            while True:
                try:
                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
                    chunks_yielded += 1

                    if not first_chunk_yielded:
                        first_chunk_yielded = True
                        elapsed = time_module.perf_counter() - event_gen_start
                        logger.info(
                            f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
                            f"type={type(chunk).__name__}",
                            extra={
                                "json_fields": {
                                    **log_meta,
                                    "chunk_type": type(chunk).__name__,
                                    "elapsed_ms": elapsed * 1000,
                                }
                            },
                        )

                    yield chunk.to_sse()

                    # Check for finish signal
                    if isinstance(chunk, StreamFinish):
                        total_time = time_module.perf_counter() - event_gen_start
                        logger.info(
                            f"[TIMING] StreamFinish received in {total_time:.2f}s; "
                            f"n_chunks={chunks_yielded}",
                            extra={
                                "json_fields": {
                                    **log_meta,
                                    "chunks_yielded": chunks_yielded,
                                    "total_time_ms": total_time * 1000,
                                }
                            },
                        )
                        break
                except asyncio.TimeoutError:
                    yield StreamHeartbeat().to_sse()

        except GeneratorExit:
            logger.info(
                f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
                extra={
                    "json_fields": {
                        **log_meta,
                        "chunks_yielded": chunks_yielded,
                        "reason": "client_disconnect",
                    }
                },
            )
            pass  # Client disconnected - background task continues
        except Exception as e:
            elapsed = (time_module.perf_counter() - event_gen_start) * 1000
            logger.error(
                f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
                extra={
                    "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
                },
            )
        finally:
            # Unsubscribe when client disconnects or stream ends to prevent resource leak
            if subscriber_queue is not None:
                try:
                    await stream_registry.unsubscribe_from_task(
                        task_id, subscriber_queue
                    )
                except Exception as unsub_err:
                    logger.error(
                        f"Error unsubscribing from task {task_id}: {unsub_err}",
                        exc_info=True,
                    )
            # AI SDK protocol termination - always yield even if unsubscribe fails
            total_time = time_module.perf_counter() - event_gen_start
            logger.info(
                f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
                f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
                extra={
                    "json_fields": {
                        **log_meta,
                        "total_time_ms": total_time * 1000,
                        "chunks_yielded": chunks_yielded,
                    }
                },
            )
            # AI SDK protocol termination
            yield "data: [DONE]\n\n"

    return StreamingResponse(
@@ -270,44 +548,53 @@ async def stream_chat_post(

 @router.get(
     "/sessions/{session_id}/stream",
 )
-async def stream_chat_get(
+async def resume_session_stream(
     session_id: str,
-    message: Annotated[str, Query(min_length=1, max_length=10000)],
     user_id: str | None = Depends(auth.get_user_id),
-    is_user_message: bool = Query(default=True),
 ):
     """
-    Stream chat responses for a session (GET - legacy endpoint).
+    Resume an active stream for a session.

-    Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
-    - Text fragments as they are generated
-    - Tool call UI elements (if invoked)
-    - Tool execution results
+    Called by the AI SDK's ``useChat(resume: true)`` on page load.
+    Checks for an active (in-progress) task on the session and either replays
+    the full SSE stream or returns 204 No Content if nothing is running.

     Args:
-        session_id: The chat session identifier to associate with the streamed messages.
-        message: The user's new message to process.
+        session_id: The chat session identifier.
         user_id: Optional authenticated user ID.
-        is_user_message: Whether the message is a user message.

     Returns:
-        StreamingResponse: SSE-formatted response chunks.
+        StreamingResponse (SSE) when an active stream exists,
+        or 204 No Content when there is nothing to resume.
     """
-    session = await _validate_and_get_session(session_id, user_id)
+    import asyncio
+
+    active_task, _last_id = await stream_registry.get_active_task_for_session(
+        session_id, user_id
+    )
+
+    if not active_task:
+        return Response(status_code=204)
+
+    subscriber_queue = await stream_registry.subscribe_to_task(
+        task_id=active_task.task_id,
+        user_id=user_id,
+        last_message_id="0-0",  # Full replay so useChat rebuilds the message
+    )
+
+    if subscriber_queue is None:
+        return Response(status_code=204)
+
     async def event_generator() -> AsyncGenerator[str, None]:
         chunk_count = 0
         first_chunk_type: str | None = None
-        async for chunk in chat_service.stream_chat_completion(
-            session_id,
-            message,
-            is_user_message=is_user_message,
-            user_id=user_id,
-            session=session,  # Pass pre-fetched session to avoid double-fetch
-        ):
+        try:
+            while True:
+                try:
+                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
                     if chunk_count < 3:
                         logger.info(
-                            "Chat stream chunk",
+                            "Resume stream chunk",
                             extra={
                                 "session_id": session_id,
                                 "chunk_type": str(chunk.type),
@@ -317,15 +604,33 @@ async def stream_chat_get(
                     first_chunk_type = str(chunk.type)
                     chunk_count += 1
                     yield chunk.to_sse()
+
+                    if isinstance(chunk, StreamFinish):
+                        break
+                except asyncio.TimeoutError:
+                    yield StreamHeartbeat().to_sse()
+        except GeneratorExit:
+            pass
+        except Exception as e:
+            logger.error(f"Error in resume stream for session {session_id}: {e}")
+        finally:
+            try:
+                await stream_registry.unsubscribe_from_task(
+                    active_task.task_id, subscriber_queue
+                )
+            except Exception as unsub_err:
+                logger.error(
+                    f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
+                    exc_info=True,
+                )
         logger.info(
-            "Chat stream completed",
+            "Resume stream completed",
             extra={
                 "session_id": session_id,
-                "chunk_count": chunk_count,
+                "n_chunks": chunk_count,
                 "first_chunk_type": first_chunk_type,
             },
         )
-        # AI SDK protocol termination
         yield "data: [DONE]\n\n"

     return StreamingResponse(
@@ -334,8 +639,8 @@ async def stream_chat_get(
         headers={
             "Cache-Control": "no-cache",
             "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",  # Disable nginx buffering
+            "X-Accel-Buffering": "no",
-            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
+            "x-vercel-ai-ui-message-stream": "v1",
         },
     )
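For orientation, a minimal client-side sketch of how the resume endpoint above is meant to be probed on page load: treat 204 as "nothing to resume", otherwise replay the SSE frames until the AI SDK terminator. The helper name, the required base_url argument, and the httpx-based plumbing are illustrative assumptions, not part of this changeset.

import httpx


async def try_resume(base_url: str, session_id: str) -> bool:
    """Return True if an active stream was replayed, False if there was nothing to resume."""
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "GET", f"{base_url}/sessions/{session_id}/stream"
        ) as resp:
            if resp.status_code == 204:
                return False  # no in-progress task for this session
            async for line in resp.aiter_lines():
                if line == "data: [DONE]":
                    break  # AI SDK protocol terminator
            return True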
@@ -366,6 +671,251 @@ async def session_assign_user(
    return {"status": "ok"}


# ========== Task Streaming (SSE Reconnection) ==========


@router.get(
    "/tasks/{task_id}/stream",
)
async def stream_task(
    task_id: str,
    user_id: str | None = Depends(auth.get_user_id),
    last_message_id: str = Query(
        default="0-0",
        description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
    ),
):
    """
    Reconnect to a long-running task's SSE stream.

    When a long-running operation (like agent generation) starts, the client
    receives a task_id. If the connection drops, the client can reconnect
    using this endpoint to resume receiving updates.

    Args:
        task_id: The task ID from the operation_started response.
        user_id: Authenticated user ID for ownership validation.
        last_message_id: Last Redis Stream message ID received ("0-0" for full replay).

    Returns:
        StreamingResponse: SSE-formatted response chunks starting after last_message_id.

    Raises:
        HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
    """
    # Check task existence and expiry before subscribing
    task, error_code = await stream_registry.get_task_with_expiry_info(task_id)

    if error_code == "TASK_EXPIRED":
        raise HTTPException(
            status_code=410,
            detail={
                "code": "TASK_EXPIRED",
                "message": "This operation has expired. Please try again.",
            },
        )

    if error_code == "TASK_NOT_FOUND":
        raise HTTPException(
            status_code=404,
            detail={
                "code": "TASK_NOT_FOUND",
                "message": f"Task {task_id} not found.",
            },
        )

    # Validate ownership if task has an owner
    if task and task.user_id and user_id != task.user_id:
        raise HTTPException(
            status_code=403,
            detail={
                "code": "ACCESS_DENIED",
                "message": "You do not have access to this task.",
            },
        )

    # Get subscriber queue from stream registry
    subscriber_queue = await stream_registry.subscribe_to_task(
        task_id=task_id,
        user_id=user_id,
        last_message_id=last_message_id,
    )

    if subscriber_queue is None:
        raise HTTPException(
            status_code=404,
            detail={
                "code": "TASK_NOT_FOUND",
                "message": f"Task {task_id} not found or access denied.",
            },
        )

    async def event_generator() -> AsyncGenerator[str, None]:
        import asyncio

        heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
        try:
            while True:
                try:
                    # Wait for next chunk with timeout for heartbeats
                    chunk = await asyncio.wait_for(
                        subscriber_queue.get(), timeout=heartbeat_interval
                    )
                    yield chunk.to_sse()

                    # Check for finish signal
                    if isinstance(chunk, StreamFinish):
                        break
                except asyncio.TimeoutError:
                    # Send heartbeat to keep connection alive
                    yield StreamHeartbeat().to_sse()
        except Exception as e:
            logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
        finally:
            # Unsubscribe when client disconnects or stream ends
            try:
                await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
            except Exception as unsub_err:
                logger.error(
                    f"Error unsubscribing from task {task_id}: {unsub_err}",
                    exc_info=True,
                )
            # AI SDK protocol termination - always yield even if unsubscribe fails
            yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "x-vercel-ai-ui-message-stream": "v1",
        },
    )
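A companion sketch for the reconnection endpoint above: the client keeps the last Redis Stream ID it has processed and passes it as last_message_id so only newer chunks are replayed, while 410 and 403/404 map to the expiry and ownership checks in the handler. The helper itself (name, base_url argument, printing the payloads) is an assumption for illustration.

import httpx


async def reconnect_task_stream(
    base_url: str, task_id: str, last_message_id: str = "0-0"
) -> None:
    """Resume a dropped task stream, replaying only messages after last_message_id."""
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "GET",
            f"{base_url}/tasks/{task_id}/stream",
            params={"last_message_id": last_message_id},
        ) as resp:
            if resp.status_code == 410:
                raise RuntimeError("Task expired, restart the operation")
            if resp.status_code in (403, 404):
                raise RuntimeError("Task not found or access denied")
            async for line in resp.aiter_lines():
                if line == "data: [DONE]":
                    break
                if line.startswith("data: "):
                    print(line[len("data: "):])  # one SSE-encoded stream chunk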
@router.get(
    "/tasks/{task_id}",
)
async def get_task_status(
    task_id: str,
    user_id: str | None = Depends(auth.get_user_id),
) -> dict:
    """
    Get the status of a long-running task.

    Args:
        task_id: The task ID to check.
        user_id: Authenticated user ID for ownership validation.

    Returns:
        dict: Task status including task_id, status, tool_name, and operation_id.

    Raises:
        NotFoundError: If task_id is not found or user doesn't have access.
    """
    task = await stream_registry.get_task(task_id)

    if task is None:
        raise NotFoundError(f"Task {task_id} not found.")

    # Validate ownership - if task has an owner, requester must match
    if task.user_id and user_id != task.user_id:
        raise NotFoundError(f"Task {task_id} not found.")

    return {
        "task_id": task.task_id,
        "session_id": task.session_id,
        "status": task.status,
        "tool_name": task.tool_name,
        "operation_id": task.operation_id,
        "created_at": task.created_at.isoformat(),
    }
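Where SSE is not convenient, the status endpoint above supports plain polling; a naive sketch follows (the helper name, base_url argument, and the 2-second interval are assumptions):

import asyncio

import httpx


async def wait_for_task(base_url: str, task_id: str) -> dict:
    """Poll GET /tasks/{task_id} until the task leaves the 'running' state."""
    async with httpx.AsyncClient() as client:
        while True:
            resp = await client.get(f"{base_url}/tasks/{task_id}")
            resp.raise_for_status()
            status = resp.json()
            if status["status"] in ("completed", "failed"):
                return status
            await asyncio.sleep(2.0)  # assumed back-off between polls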
# ========== External Completion Webhook ==========


@router.post(
    "/operations/{operation_id}/complete",
    status_code=200,
)
async def complete_operation(
    operation_id: str,
    request: OperationCompleteRequest,
    x_api_key: str | None = Header(default=None),
) -> dict:
    """
    External completion webhook for long-running operations.

    Called by Agent Generator (or other services) when an operation completes.
    This triggers the stream registry to publish completion and continue LLM generation.

    Args:
        operation_id: The operation ID to complete.
        request: Completion payload with success status and result/error.
        x_api_key: Internal API key for authentication.

    Returns:
        dict: Status of the completion.

    Raises:
        HTTPException: If API key is invalid or operation not found.
    """
    # Validate internal API key - reject if not configured or invalid
    if not config.internal_api_key:
        logger.error(
            "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
        )
        raise HTTPException(
            status_code=503,
            detail="Webhook not available: internal API key not configured",
        )
    if x_api_key != config.internal_api_key:
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Find task by operation_id
    task = await stream_registry.find_task_by_operation_id(operation_id)
    if task is None:
        raise HTTPException(
            status_code=404,
            detail=f"Operation {operation_id} not found",
        )

    logger.info(
        f"Received completion webhook for operation {operation_id} "
        f"(task_id={task.task_id}, success={request.success})"
    )

    if request.success:
        await process_operation_success(task, request.result)
    else:
        await process_operation_failure(task, request.error)

    return {"status": "ok", "task_id": task.task_id}


# ========== Configuration ==========


@router.get("/config/ttl", status_code=200)
async def get_ttl_config() -> dict:
    """
    Get the stream TTL configuration.

    Returns the Time-To-Live settings for chat streams, which determines
    how long clients can reconnect to an active stream.

    Returns:
        dict: TTL configuration with seconds and milliseconds values.
    """
    return {
        "stream_ttl_seconds": config.stream_ttl,
        "stream_ttl_ms": config.stream_ttl * 1000,
    }
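For reference, this is roughly the call an external service such as the Agent Generator would make against the completion webhook above; the X-API-Key header corresponds to the x_api_key Header parameter, while the helper name, base_url argument, and any payload fields beyond success/result/error are assumptions.

import httpx


async def notify_operation_complete(
    base_url: str, operation_id: str, api_key: str, result: dict
) -> None:
    """Report a finished long-running operation back to the chat service."""
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{base_url}/operations/{operation_id}/complete",
            headers={"X-API-Key": api_key},
            json={"success": True, "result": result},
        )
        resp.raise_for_status()  # 401 on a bad key, 404 on an unknown operation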
# ========== Health Check ==========


@@ -402,3 +952,42 @@ async def health_check() -> dict:
        "service": "chat",
        "version": "0.1.0",
    }


# ========== Schema Export (for OpenAPI / Orval codegen) ==========

ToolResponseUnion = (
    AgentsFoundResponse
    | NoResultsResponse
    | AgentDetailsResponse
    | SetupRequirementsResponse
    | ExecutionStartedResponse
    | NeedLoginResponse
    | ErrorResponse
    | InputValidationErrorResponse
    | AgentOutputResponse
    | UnderstandingUpdatedResponse
    | AgentPreviewResponse
    | AgentSavedResponse
    | ClarificationNeededResponse
    | BlockListResponse
    | BlockOutputResponse
    | DocSearchResultsResponse
    | DocPageResponse
    | OperationStartedResponse
    | OperationPendingResponse
    | OperationInProgressResponse
)


@router.get(
    "/schema/tool-responses",
    response_model=ToolResponseUnion,
    include_in_schema=True,
    summary="[Dummy] Tool response type export for codegen",
    description="This endpoint is not meant to be called. It exists solely to "
    "expose tool response models in the OpenAPI schema for frontend codegen.",
)
async def _tool_response_schema() -> ToolResponseUnion:  # type: ignore[return]
    """Never called at runtime. Exists only so Orval generates TS types."""
    raise HTTPException(status_code=501, detail="Schema-only endpoint")
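Since the union above exists only so the response models reach the OpenAPI schema, any runtime discrimination would happen on the consumer side; a hedged sketch using pydantic's TypeAdapter (this helper is not part of the diff and assumes the response models are pydantic v2 models):

import orjson
from pydantic import TypeAdapter

_tool_response_adapter = TypeAdapter(ToolResponseUnion)


def parse_tool_response(raw: bytes) -> ToolResponseUnion:
    """Validate a raw tool-response payload against the exported union."""
    return _tool_response_adapter.validate_python(orjson.loads(raw))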
@@ -33,9 +33,10 @@ from backend.data.understanding import (
     get_business_understanding,
 )
 from backend.util.exceptions import NotFoundError
-from backend.util.settings import Settings
+from backend.util.settings import AppEnvironment, Settings

 from . import db as chat_db
+from . import stream_registry
 from .config import ChatConfig
 from .model import (
     ChatMessage,
@@ -51,8 +52,10 @@ from .response_model import (
     StreamBaseResponse,
     StreamError,
     StreamFinish,
+    StreamFinishStep,
     StreamHeartbeat,
     StreamStart,
+    StreamStartStep,
     StreamTextDelta,
     StreamTextEnd,
     StreamTextStart,
@@ -221,8 +224,18 @@ async def _get_system_prompt_template(context: str) -> str:
     try:
         # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
         # Use asyncio.to_thread to avoid blocking the event loop
+        # In non-production environments, fetch the latest prompt version
+        # instead of the production-labeled version for easier testing
+        label = (
+            None
+            if settings.config.app_env == AppEnvironment.PRODUCTION
+            else "latest"
+        )
         prompt = await asyncio.to_thread(
-            langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
+            langfuse.get_prompt,
+            config.langfuse_prompt_name,
+            label=label,
+            cache_ttl_seconds=0,
         )
         return prompt.compile(users_information=context)
     except Exception as e:
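The label selection added above reduces to a one-line policy; isolated as a sketch (the function name is illustrative, not part of the diff):

from backend.util.settings import AppEnvironment


def prompt_label_for_env(app_env: AppEnvironment) -> str | None:
    # None lets Langfuse resolve the production-labeled prompt version;
    # every other environment pins to the most recently created ("latest") version.
    return None if app_env == AppEnvironment.PRODUCTION else "latest"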
@@ -340,6 +353,10 @@ async def stream_chat_completion(
     retry_count: int = 0,
     session: ChatSession | None = None,
     context: dict[str, str] | None = None,  # {url: str, content: str}
+    _continuation_message_id: (
+        str | None
+    ) = None,  # Internal: reuse message ID for tool call continuations
+    _task_id: str | None = None,  # Internal: task ID for SSE reconnection support
) -> AsyncGenerator[StreamBaseResponse, None]:
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
"""Main entry point for streaming chat completions with database handling.
|
"""Main entry point for streaming chat completions with database handling.
|
||||||
|
|
||||||
@@ -360,21 +377,45 @@ async def stream_chat_completion(
|
|||||||
ValueError: If max_context_messages is exceeded
|
ValueError: If max_context_messages is exceeded
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
completion_start = time.monotonic()
|
||||||
|
|
||||||
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {"component": "ChatService", "session_id": session_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
||||||
|
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"message_len": len(message) if message else 0,
|
||||||
|
"is_user_message": is_user_message,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only fetch from Redis if session not provided (initial call)
|
# Only fetch from Redis if session not provided (initial call)
|
||||||
if session is None:
|
if session is None:
|
||||||
|
fetch_start = time.monotonic()
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
fetch_time = (time.monotonic() - fetch_start) * 1000
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
|
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
||||||
f"message_count={len(session.messages) if session else 0}"
|
f"n_messages={len(session.messages) if session else 0}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": fetch_time,
|
||||||
|
"n_messages": len(session.messages) if session else 0,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Using provided session object: {session.session_id}, "
|
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
||||||
f"message_count={len(session.messages)}"
|
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||||
)
|
)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
@@ -395,17 +436,25 @@ async def stream_chat_completion(
|
|||||||
|
|
||||||
# Track user message in PostHog
|
# Track user message in PostHog
|
||||||
if is_user_message:
|
if is_user_message:
|
||||||
|
posthog_start = time.monotonic()
|
||||||
track_user_message(
|
track_user_message(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
message_length=len(message),
|
message_length=len(message),
|
||||||
)
|
)
|
||||||
|
posthog_time = (time.monotonic() - posthog_start) * 1000
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Upserting session: {session.session_id} with user id {session.user_id}, "
|
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
|
||||||
f"message_count={len(session.messages)}"
|
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upsert_start = time.monotonic()
|
||||||
session = await upsert_chat_session(session)
|
session = await upsert_chat_session(session)
|
||||||
|
upsert_time = (time.monotonic() - upsert_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
|
||||||
|
)
|
||||||
assert session, "Session not found"
|
assert session, "Session not found"
|
||||||
|
|
||||||
# Generate title for new sessions on first user message (non-blocking)
|
# Generate title for new sessions on first user message (non-blocking)
|
||||||
@@ -443,7 +492,13 @@ async def stream_chat_completion(
|
|||||||
asyncio.create_task(_update_title())
|
asyncio.create_task(_update_title())
|
||||||
|
|
||||||
# Build system prompt with business understanding
|
# Build system prompt with business understanding
|
||||||
|
prompt_start = time.monotonic()
|
||||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||||
|
prompt_time = (time.monotonic() - prompt_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize variables for streaming
|
# Initialize variables for streaming
|
||||||
assistant_response = ChatMessage(
|
assistant_response = ChatMessage(
|
||||||
@@ -468,13 +523,27 @@ async def stream_chat_completion(
|
|||||||
# Generate unique IDs for AI SDK protocol
|
# Generate unique IDs for AI SDK protocol
|
||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
|
|
||||||
message_id = str(uuid_module.uuid4())
|
is_continuation = _continuation_message_id is not None
|
||||||
|
message_id = _continuation_message_id or str(uuid_module.uuid4())
|
||||||
text_block_id = str(uuid_module.uuid4())
|
text_block_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
# Yield message start
|
# Only yield message start for the initial call, not for continuations.
|
||||||
yield StreamStart(messageId=message_id)
|
setup_time = (time.monotonic() - completion_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||||
|
)
|
||||||
|
if not is_continuation:
|
||||||
|
yield StreamStart(messageId=message_id, taskId=_task_id)
|
||||||
|
|
||||||
|
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
|
||||||
|
yield StreamStartStep()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(
|
||||||
|
"[TIMING] Calling _stream_chat_chunks",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
async for chunk in _stream_chat_chunks(
|
async for chunk in _stream_chat_chunks(
|
||||||
session=session,
|
session=session,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -574,6 +643,10 @@ async def stream_chat_completion(
|
|||||||
)
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamFinish):
|
elif isinstance(chunk, StreamFinish):
|
||||||
|
if has_done_tool_call:
|
||||||
|
# Tool calls happened — close the step but don't send message-level finish.
|
||||||
|
# The continuation will open a new step, and finish will come at the end.
|
||||||
|
yield StreamFinishStep()
|
||||||
if not has_done_tool_call:
|
if not has_done_tool_call:
|
||||||
# Emit text-end before finish if we received text but haven't closed it
|
# Emit text-end before finish if we received text but haven't closed it
|
||||||
if has_received_text and not text_streaming_ended:
|
if has_received_text and not text_streaming_ended:
|
||||||
@@ -605,6 +678,8 @@ async def stream_chat_completion(
|
|||||||
has_saved_assistant_message = True
|
has_saved_assistant_message = True
|
||||||
|
|
||||||
has_yielded_end = True
|
has_yielded_end = True
|
||||||
|
# Emit finish-step before finish (resets AI SDK text/reasoning state)
|
||||||
|
yield StreamFinishStep()
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamError):
|
elif isinstance(chunk, StreamError):
|
||||||
has_yielded_error = True
|
has_yielded_error = True
|
||||||
@@ -617,6 +692,9 @@ async def stream_chat_completion(
|
|||||||
total_tokens=chunk.totalTokens,
|
total_tokens=chunk.totalTokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif isinstance(chunk, StreamHeartbeat):
|
||||||
|
# Pass through heartbeat to keep SSE connection alive
|
||||||
|
yield chunk
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||||
|
|
||||||
@@ -651,6 +729,10 @@ async def stream_chat_completion(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
||||||
)
|
)
|
||||||
|
# Close the current step before retrying so the recursive call's
|
||||||
|
# StreamStartStep doesn't produce unbalanced step events.
|
||||||
|
if not has_yielded_end:
|
||||||
|
yield StreamFinishStep()
|
||||||
should_retry = True
|
should_retry = True
|
||||||
else:
|
else:
|
||||||
# Non-retryable error or max retries exceeded
|
# Non-retryable error or max retries exceeded
|
||||||
@@ -686,6 +768,7 @@ async def stream_chat_completion(
|
|||||||
error_response = StreamError(errorText=error_message)
|
error_response = StreamError(errorText=error_message)
|
||||||
yield error_response
|
yield error_response
|
||||||
if not has_yielded_end:
|
if not has_yielded_end:
|
||||||
|
yield StreamFinishStep()
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -700,6 +783,8 @@ async def stream_chat_completion(
|
|||||||
retry_count=retry_count + 1,
|
retry_count=retry_count + 1,
|
||||||
session=session,
|
session=session,
|
||||||
context=context,
|
context=context,
|
||||||
|
_continuation_message_id=message_id, # Reuse message ID since start was already sent
|
||||||
|
_task_id=_task_id,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
return # Exit after retry to avoid double-saving in finally block
|
return # Exit after retry to avoid double-saving in finally block
|
||||||
@@ -715,9 +800,13 @@ async def stream_chat_completion(
|
|||||||
# Build the messages list in the correct order
|
# Build the messages list in the correct order
|
||||||
messages_to_save: list[ChatMessage] = []
|
messages_to_save: list[ChatMessage] = []
|
||||||
|
|
||||||
# Add assistant message with tool_calls if any
|
# Add assistant message with tool_calls if any.
|
||||||
|
# Use extend (not assign) to preserve tool_calls already added by
|
||||||
|
# _yield_tool_call for long-running tools.
|
||||||
if accumulated_tool_calls:
|
if accumulated_tool_calls:
|
||||||
assistant_response.tool_calls = accumulated_tool_calls
|
if not assistant_response.tool_calls:
|
||||||
|
assistant_response.tool_calls = []
|
||||||
|
assistant_response.tool_calls.extend(accumulated_tool_calls)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||||
)
|
)
|
||||||
@@ -769,6 +858,8 @@ async def stream_chat_completion(
|
|||||||
session=session, # Pass session object to avoid Redis refetch
|
session=session, # Pass session object to avoid Redis refetch
|
||||||
context=context,
|
context=context,
|
||||||
tool_call_response=str(tool_response_messages),
|
tool_call_response=str(tool_response_messages),
|
||||||
|
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
|
||||||
|
_task_id=_task_id,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@@ -879,9 +970,21 @@ async def _stream_chat_chunks(
|
|||||||
SSE formatted JSON response objects
|
SSE formatted JSON response objects
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import time as time_module
|
||||||
|
|
||||||
|
stream_chunks_start = time_module.perf_counter()
|
||||||
model = config.model
|
model = config.model
|
||||||
|
|
||||||
logger.info("Starting pure chat stream")
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {"component": "ChatService", "session_id": session.session_id}
|
||||||
|
if session.user_id:
|
||||||
|
log_meta["user_id"] = session.user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
|
||||||
|
f"user={session.user_id}, n_messages={len(session.messages)}",
|
||||||
|
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||||
|
)
|
||||||
|
|
||||||
messages = session.to_openai_messages()
|
messages = session.to_openai_messages()
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
@@ -892,12 +995,18 @@ async def _stream_chat_chunks(
|
|||||||
messages = [system_message] + messages
|
messages = [system_message] + messages
|
||||||
|
|
||||||
# Apply context window management
|
# Apply context window management
|
||||||
|
context_start = time_module.perf_counter()
|
||||||
context_result = await _manage_context_window(
|
context_result = await _manage_context_window(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=model,
|
model=model,
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
base_url=config.base_url,
|
base_url=config.base_url,
|
||||||
)
|
)
|
||||||
|
context_time = (time_module.perf_counter() - context_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
|
||||||
|
)
|
||||||
|
|
||||||
if context_result.error:
|
if context_result.error:
|
||||||
if "System prompt dropped" in context_result.error:
|
if "System prompt dropped" in context_result.error:
|
||||||
@@ -932,9 +1041,19 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
while retry_count <= MAX_RETRIES:
|
while retry_count <= MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
|
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||||
|
retry_info = (
|
||||||
|
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Creating OpenAI chat completion stream..."
|
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
|
||||||
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"retry_count": retry_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build extra_body for OpenRouter tracing and PostHog analytics
|
# Build extra_body for OpenRouter tracing and PostHog analytics
|
||||||
@@ -951,6 +1070,11 @@ async def _stream_chat_chunks(
|
|||||||
:128
|
:128
|
||||||
] # OpenRouter limit
|
] # OpenRouter limit
|
||||||
|
|
||||||
|
# Enable adaptive thinking for Anthropic models via OpenRouter
|
||||||
|
if config.thinking_enabled and "anthropic" in model.lower():
|
||||||
|
extra_body["reasoning"] = {"enabled": True}
|
||||||
|
|
||||||
|
api_call_start = time_module.perf_counter()
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||||
@@ -960,6 +1084,11 @@ async def _stream_chat_chunks(
|
|||||||
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
|
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# Variables to accumulate tool calls
|
# Variables to accumulate tool calls
|
||||||
tool_calls: list[dict[str, Any]] = []
|
tool_calls: list[dict[str, Any]] = []
|
||||||
@@ -970,10 +1099,13 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
# Track if we've started the text block
|
# Track if we've started the text block
|
||||||
text_started = False
|
text_started = False
|
||||||
|
first_content_chunk = True
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
# Process the stream
|
# Process the stream
|
||||||
chunk: ChatCompletionChunk
|
chunk: ChatCompletionChunk
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
|
chunk_count += 1
|
||||||
if chunk.usage:
|
if chunk.usage:
|
||||||
yield StreamUsage(
|
yield StreamUsage(
|
||||||
promptTokens=chunk.usage.prompt_tokens,
|
promptTokens=chunk.usage.prompt_tokens,
|
||||||
@@ -996,6 +1128,23 @@ async def _stream_chat_chunks(
|
|||||||
if not text_started and text_block_id:
|
if not text_started and text_block_id:
|
||||||
yield StreamTextStart(id=text_block_id)
|
yield StreamTextStart(id=text_block_id)
|
||||||
text_started = True
|
text_started = True
|
||||||
|
# Log timing for first content chunk
|
||||||
|
if first_content_chunk:
|
||||||
|
first_content_chunk = False
|
||||||
|
ttfc = (
|
||||||
|
time_module.perf_counter() - api_call_start
|
||||||
|
) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
|
||||||
|
f"(since API call), n_chunks={chunk_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"time_to_first_chunk_ms": ttfc,
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
# Stream the text delta
|
# Stream the text delta
|
||||||
text_response = StreamTextDelta(
|
text_response = StreamTextDelta(
|
||||||
id=text_block_id or "",
|
id=text_block_id or "",
|
||||||
@@ -1052,7 +1201,21 @@ async def _stream_chat_chunks(
|
|||||||
toolName=tool_calls[idx]["function"]["name"],
|
toolName=tool_calls[idx]["function"]["name"],
|
||||||
)
|
)
|
||||||
emitted_start_for_idx.add(idx)
|
emitted_start_for_idx.add(idx)
|
||||||
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
stream_duration = time_module.perf_counter() - api_call_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
|
||||||
|
f"duration={stream_duration:.2f}s, "
|
||||||
|
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"stream_duration_ms": stream_duration * 1000,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
"n_tool_calls": len(tool_calls),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Yield all accumulated tool calls after the stream is complete
|
# Yield all accumulated tool calls after the stream is complete
|
||||||
# This ensures all tool call arguments have been fully received
|
# This ensures all tool call arguments have been fully received
|
||||||
@@ -1072,6 +1235,12 @@ async def _stream_chat_chunks(
|
|||||||
# Re-raise to trigger retry logic in the parent function
|
# Re-raise to trigger retry logic in the parent function
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
|
||||||
|
f"session={session.session_id}, user={session.user_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||||
|
)
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1184,8 +1353,9 @@ async def _yield_tool_call(
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Generate operation ID
|
# Generate operation ID and task ID
|
||||||
operation_id = str(uuid_module.uuid4())
|
operation_id = str(uuid_module.uuid4())
|
||||||
|
task_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
# Build a user-friendly message based on tool and arguments
|
# Build a user-friendly message based on tool and arguments
|
||||||
if tool_name == "create_agent":
|
if tool_name == "create_agent":
|
||||||
@@ -1228,13 +1398,19 @@ async def _yield_tool_call(
|
|||||||
|
|
||||||
# Wrap session save and task creation in try-except to release lock on failure
|
# Wrap session save and task creation in try-except to release lock on failure
|
||||||
try:
|
try:
|
||||||
# Save assistant message with tool_call FIRST (required by LLM)
|
# Create task in stream registry for SSE reconnection support
|
||||||
assistant_message = ChatMessage(
|
await stream_registry.create_task(
|
||||||
role="assistant",
|
task_id=task_id,
|
||||||
content="",
|
session_id=session.session_id,
|
||||||
tool_calls=[tool_calls[yield_idx]],
|
user_id=session.user_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
operation_id=operation_id,
|
||||||
)
|
)
|
||||||
session.messages.append(assistant_message)
|
|
||||||
|
# Attach the tool_call to the current turn's assistant message
|
||||||
|
# (or create one if this is a tool-only response with no text).
|
||||||
|
session.add_tool_call_to_current_turn(tool_calls[yield_idx])
|
||||||
|
|
||||||
# Then save pending tool result
|
# Then save pending tool result
|
||||||
pending_message = ChatMessage(
|
pending_message = ChatMessage(
|
||||||
@@ -1249,23 +1425,27 @@ async def _yield_tool_call(
|
|||||||
session.messages.append(pending_message)
|
session.messages.append(pending_message)
|
||||||
await upsert_chat_session(session)
|
await upsert_chat_session(session)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Saved pending operation {operation_id} for tool {tool_name} "
|
f"Saved pending operation {operation_id} (task_id={task_id}) "
|
||||||
f"in session {session.session_id}"
|
f"for tool {tool_name} in session {session.session_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store task reference in module-level set to prevent GC before completion
|
# Store task reference in module-level set to prevent GC before completion
|
||||||
task = asyncio.create_task(
|
bg_task = asyncio.create_task(
|
||||||
_execute_long_running_tool(
|
_execute_long_running_tool_with_streaming(
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
parameters=arguments,
|
parameters=arguments,
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
task_id=task_id,
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
user_id=session.user_id,
|
user_id=session.user_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
_background_tasks.add(task)
|
_background_tasks.add(bg_task)
|
||||||
task.add_done_callback(_background_tasks.discard)
|
bg_task.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
|
# Associate the asyncio task with the stream registry task
|
||||||
|
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Roll back appended messages to prevent data corruption on subsequent saves
|
# Roll back appended messages to prevent data corruption on subsequent saves
|
||||||
if (
|
if (
|
||||||
@@ -1283,6 +1463,11 @@ async def _yield_tool_call(
|
|||||||
|
|
||||||
# Release the Redis lock since the background task won't be spawned
|
# Release the Redis lock since the background task won't be spawned
|
||||||
await _mark_operation_completed(tool_call_id)
|
await _mark_operation_completed(tool_call_id)
|
||||||
|
# Mark stream registry task as failed if it was created
|
||||||
|
try:
|
||||||
|
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
|
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
|
||||||
)
|
)
|
||||||
@@ -1296,6 +1481,7 @@ async def _yield_tool_call(
|
|||||||
message=started_msg,
|
message=started_msg,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
|
task_id=task_id, # Include task_id for SSE reconnection
|
||||||
).model_dump_json(),
|
).model_dump_json(),
|
||||||
success=True,
|
success=True,
|
||||||
)
|
)
|
||||||
@@ -1365,6 +1551,9 @@ async def _execute_long_running_tool(
|
|||||||
|
|
||||||
This function runs independently of the SSE connection, so the operation
|
This function runs independently of the SSE connection, so the operation
|
||||||
survives if the user closes their browser tab.
|
survives if the user closes their browser tab.
|
||||||
|
|
||||||
|
NOTE: This is the legacy function without stream registry support.
|
||||||
|
Use _execute_long_running_tool_with_streaming for new implementations.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Load fresh session (not stale reference)
|
# Load fresh session (not stale reference)
|
||||||
@@ -1417,6 +1606,134 @@ async def _execute_long_running_tool(
|
|||||||
await _mark_operation_completed(tool_call_id)
|
await _mark_operation_completed(tool_call_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_long_running_tool_with_streaming(
|
||||||
|
tool_name: str,
|
||||||
|
parameters: dict[str, Any],
|
||||||
|
tool_call_id: str,
|
||||||
|
operation_id: str,
|
||||||
|
task_id: str,
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a long-running tool with stream registry support for SSE reconnection.
|
||||||
|
|
||||||
|
This function runs independently of the SSE connection, publishes progress
|
||||||
|
to the stream registry, and survives if the user closes their browser tab.
|
||||||
|
Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
|
||||||
|
|
||||||
|
If the external service returns a 202 Accepted (async), this function exits
|
||||||
|
early and lets the Redis Streams completion consumer handle the rest.
|
||||||
|
"""
|
||||||
|
# Track whether we delegated to async processing - if so, the Redis Streams
|
||||||
|
# completion consumer (stream_registry / completion_consumer) will handle cleanup, not us
|
||||||
|
delegated_to_async = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load fresh session (not stale reference)
|
||||||
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
if not session:
|
||||||
|
logger.error(f"Session {session_id} not found for background tool")
|
||||||
|
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Pass operation_id and task_id to the tool for async processing
|
||||||
|
enriched_parameters = {
|
||||||
|
**parameters,
|
||||||
|
"_operation_id": operation_id,
|
||||||
|
"_task_id": task_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute the actual tool
|
||||||
|
result = await execute_tool(
|
||||||
|
tool_name=tool_name,
|
||||||
|
parameters=enriched_parameters,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
user_id=user_id,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the tool result indicates async processing
|
||||||
|
# (e.g., Agent Generator returned 202 Accepted)
|
||||||
|
try:
|
||||||
|
if isinstance(result.output, dict):
|
||||||
|
result_data = result.output
|
||||||
|
elif result.output:
|
||||||
|
result_data = orjson.loads(result.output)
|
||||||
|
else:
|
||||||
|
result_data = {}
|
||||||
|
if result_data.get("status") == "accepted":
|
||||||
|
logger.info(
|
||||||
|
f"Tool {tool_name} delegated to async processing "
|
||||||
|
f"(operation_id={operation_id}, task_id={task_id}). "
|
||||||
|
f"Redis Streams completion consumer will handle the rest."
|
||||||
|
)
|
||||||
|
# Don't publish result, don't continue with LLM, and don't cleanup
|
||||||
|
# The Redis Streams consumer (completion_consumer) will handle
|
||||||
|
# everything when the external service completes via webhook
|
||||||
|
delegated_to_async = True
|
||||||
|
return
|
||||||
|
except (orjson.JSONDecodeError, TypeError):
|
||||||
|
pass # Not JSON or not async - continue normally
|
||||||
|
|
||||||
|
# Publish tool result to stream registry
|
||||||
|
await stream_registry.publish_chunk(task_id, result)
|
||||||
|
|
||||||
|
# Update the pending message with result
|
||||||
|
result_str = (
|
||||||
|
result.output
|
||||||
|
if isinstance(result.output, str)
|
||||||
|
else orjson.dumps(result.output).decode("utf-8")
|
||||||
|
)
|
||||||
|
await _update_pending_operation(
|
||||||
|
session_id=session_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
result=result_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Background tool {tool_name} completed for session {session_id} "
|
||||||
|
f"(task_id={task_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate LLM continuation and stream chunks to registry
|
||||||
|
await _generate_llm_continuation_with_streaming(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark task as completed in stream registry
|
||||||
|
await stream_registry.mark_task_completed(task_id, status="completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message=f"Tool {tool_name} failed: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Publish error to stream registry followed by finish event
|
||||||
|
await stream_registry.publish_chunk(
|
||||||
|
task_id,
|
||||||
|
StreamError(errorText=str(e)),
|
||||||
|
)
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
|
||||||
|
await _update_pending_operation(
|
||||||
|
session_id=session_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
result=error_response.model_dump_json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark task as failed in stream registry
|
||||||
|
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||||
|
finally:
|
||||||
|
# Only cleanup if we didn't delegate to async processing
|
||||||
|
# For async path, the Redis Streams completion consumer handles cleanup
|
||||||
|
if not delegated_to_async:
|
||||||
|
await _mark_operation_completed(tool_call_id)
|
||||||
|
|
||||||
|
|
||||||
async def _update_pending_operation(
|
async def _update_pending_operation(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
@@ -1516,6 +1833,10 @@ async def _generate_llm_continuation(
|
|||||||
if session_id:
|
if session_id:
|
||||||
extra_body["session_id"] = session_id[:128]
|
extra_body["session_id"] = session_id[:128]
|
||||||
|
|
||||||
|
# Enable adaptive thinking for Anthropic models via OpenRouter
|
||||||
|
if config.thinking_enabled and "anthropic" in config.model.lower():
|
||||||
|
extra_body["reasoning"] = {"enabled": True}
|
||||||
|
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
last_error: Exception | None = None
|
last_error: Exception | None = None
|
||||||
response = None
|
response = None
|
||||||
@@ -1597,3 +1918,135 @@ async def _generate_llm_continuation(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_llm_continuation_with_streaming(
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
task_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Generate an LLM response with streaming to the stream registry.
|
||||||
|
|
||||||
|
This is called by background tasks to continue the conversation
|
||||||
|
after a tool result is saved. Chunks are published to the stream registry
|
||||||
|
so reconnecting clients can receive them.
|
||||||
|
"""
|
||||||
|
import uuid as uuid_module
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load fresh session from DB (bypass cache to get the updated tool result)
|
||||||
|
await invalidate_session_cache(session_id)
|
||||||
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
if not session:
|
||||||
|
logger.error(f"Session {session_id} not found for LLM continuation")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build system prompt
|
||||||
|
system_prompt, _ = await _build_system_prompt(user_id)
|
||||||
|
|
||||||
|
# Build messages in OpenAI format
|
||||||
|
messages = session.to_openai_messages()
|
||||||
|
if system_prompt:
|
||||||
|
from openai.types.chat import ChatCompletionSystemMessageParam
|
||||||
|
|
||||||
|
system_message = ChatCompletionSystemMessageParam(
|
||||||
|
role="system",
|
||||||
|
content=system_prompt,
|
||||||
|
)
|
||||||
|
messages = [system_message] + messages
|
||||||
|
|
||||||
|
# Build extra_body for tracing
|
||||||
|
extra_body: dict[str, Any] = {
|
||||||
|
"posthogProperties": {
|
||||||
|
"environment": settings.config.app_env.value,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if user_id:
|
||||||
|
extra_body["user"] = user_id[:128]
|
||||||
|
extra_body["posthogDistinctId"] = user_id
|
||||||
|
if session_id:
|
||||||
|
extra_body["session_id"] = session_id[:128]
|
||||||
|
|
||||||
|
# Enable adaptive thinking for Anthropic models via OpenRouter
|
||||||
|
if config.thinking_enabled and "anthropic" in config.model.lower():
|
||||||
|
extra_body["reasoning"] = {"enabled": True}
|
||||||
|
|
||||||
|
# Make streaming LLM call (no tools - just text response)
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
|
||||||
|
# Generate unique IDs for AI SDK protocol
|
||||||
|
message_id = str(uuid_module.uuid4())
|
||||||
|
text_block_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
|
# Publish start event
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamStartStep())
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
||||||
|
|
||||||
|
# Stream the response
|
||||||
|
stream = await client.chat.completions.create(
|
||||||
|
model=config.model,
|
||||||
|
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||||
|
extra_body=extra_body,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assistant_content = ""
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.choices and chunk.choices[0].delta.content:
|
||||||
|
delta = chunk.choices[0].delta.content
|
||||||
|
assistant_content += delta
|
||||||
|
# Publish delta to stream registry
|
||||||
|
await stream_registry.publish_chunk(
|
||||||
|
task_id,
|
||||||
|
StreamTextDelta(id=text_block_id, delta=delta),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Publish end events
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
|
|
||||||
|
if assistant_content:
|
||||||
|
# Reload session from DB to avoid race condition with user messages
|
||||||
|
fresh_session = await get_chat_session(session_id, user_id)
|
||||||
|
if not fresh_session:
|
||||||
|
logger.error(
|
||||||
|
f"Session {session_id} disappeared during LLM continuation"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save assistant message to database
|
||||||
|
assistant_message = ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=assistant_content,
|
||||||
|
)
|
||||||
|
fresh_session.messages.append(assistant_message)
|
||||||
|
|
||||||
|
# Save to database (not cache) to persist the response
|
||||||
|
await upsert_chat_session(fresh_session)
|
||||||
|
|
||||||
|
# Invalidate cache so next poll/refresh gets fresh data
|
||||||
|
await invalidate_session_cache(session_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Generated streaming LLM continuation for session {session_id} "
|
||||||
|
f"(task_id={task_id}), response length: {len(assistant_content)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Streaming LLM continuation returned empty response for {session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to generate streaming LLM continuation: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
# Publish error to stream registry followed by finish event
|
||||||
|
await stream_registry.publish_chunk(
|
||||||
|
task_id,
|
||||||
|
StreamError(errorText=f"Failed to generate response: {e}"),
|
||||||
|
)
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
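# --- Illustrative sketch (not part of the diff) ---
# How a consumer might reassemble the assistant text from the chunk sequence
# published above (StreamStart -> StreamTextStart -> StreamTextDelta* ->
# StreamTextEnd -> StreamFinish). It assumes the subscribe_to_task /
# unsubscribe_from_task helpers from the stream registry module introduced
# below; field access on the chunk models follows their constructor arguments.
async def collect_streamed_text(task_id: str, user_id: str | None) -> str:
    queue = await stream_registry.subscribe_to_task(task_id, user_id)
    if queue is None:
        return ""
    parts: list[str] = []
    try:
        while True:
            chunk = await queue.get()
            if isinstance(chunk, StreamTextDelta):
                parts.append(chunk.delta)
            elif isinstance(chunk, StreamFinish):
                break
    finally:
        await stream_registry.unsubscribe_from_task(task_id, queue)
    return "".join(parts)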
@@ -0,0 +1,967 @@
"""Stream registry for managing reconnectable SSE streams.

This module provides a registry for tracking active streaming tasks and their
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.

Architecture:
- Redis Stream: Persists all messages for replay and real-time delivery
- Redis Hash: Task metadata (status, session_id, etc.)

Subscribers:
1. Replay missed messages from Redis Stream (XREAD)
2. Listen for live updates via blocking XREAD
3. No in-memory state required on the subscribing pod
"""

import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal

import orjson

from backend.data.redis_client import get_redis_async

from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamError, StreamFinish

logger = logging.getLogger(__name__)
config = ChatConfig()

# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}

# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}

# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0

# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
COMPLETE_TASK_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
    redis.call("HSET", KEYS[1], "status", ARGV[1])
    return 1
end
return 0
"""


@dataclass
class ActiveTask:
    """Represents an active streaming task (metadata only, no in-memory queues)."""

    task_id: str
    session_id: str
    user_id: str | None
    tool_call_id: str
    tool_name: str
    operation_id: str
    status: Literal["running", "completed", "failed"] = "running"
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    asyncio_task: asyncio.Task | None = None


def _get_task_meta_key(task_id: str) -> str:
    """Get Redis key for task metadata."""
    return f"{config.task_meta_prefix}{task_id}"


def _get_task_stream_key(task_id: str) -> str:
    """Get Redis key for task message stream."""
    return f"{config.task_stream_prefix}{task_id}"


def _get_operation_mapping_key(operation_id: str) -> str:
    """Get Redis key for operation_id to task_id mapping."""
    return f"{config.task_op_prefix}{operation_id}"


async def create_task(
    task_id: str,
    session_id: str,
    user_id: str | None,
    tool_call_id: str,
    tool_name: str,
    operation_id: str,
) -> ActiveTask:
    """Create a new streaming task in Redis.

    Args:
        task_id: Unique identifier for the task
        session_id: Chat session ID
        user_id: User ID (may be None for anonymous)
        tool_call_id: Tool call ID from the LLM
        tool_name: Name of the tool being executed
        operation_id: Operation ID for webhook callbacks

    Returns:
        The created ActiveTask instance (metadata only)
    """
    import time

    start_time = time.perf_counter()

    # Build log metadata for structured logging
    log_meta = {
        "component": "StreamRegistry",
        "task_id": task_id,
        "session_id": session_id,
    }
    if user_id:
        log_meta["user_id"] = user_id

    logger.info(
        f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
        extra={"json_fields": log_meta},
    )

    task = ActiveTask(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        operation_id=operation_id,
    )

    # Store metadata in Redis
    redis_start = time.perf_counter()
    redis = await get_redis_async()
    redis_time = (time.perf_counter() - redis_start) * 1000
    logger.info(
        f"[TIMING] get_redis_async took {redis_time:.1f}ms",
        extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
    )

    meta_key = _get_task_meta_key(task_id)
    op_key = _get_operation_mapping_key(operation_id)

    hset_start = time.perf_counter()
    await redis.hset(  # type: ignore[misc]
        meta_key,
        mapping={
            "task_id": task_id,
            "session_id": session_id,
            "user_id": user_id or "",
            "tool_call_id": tool_call_id,
            "tool_name": tool_name,
            "operation_id": operation_id,
            "status": task.status,
            "created_at": task.created_at.isoformat(),
        },
    )
    hset_time = (time.perf_counter() - hset_start) * 1000
    logger.info(
        f"[TIMING] redis.hset took {hset_time:.1f}ms",
        extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
    )

    await redis.expire(meta_key, config.stream_ttl)

    # Create operation_id -> task_id mapping for webhook lookups
    await redis.set(op_key, task_id, ex=config.stream_ttl)

    total_time = (time.perf_counter() - start_time) * 1000
    logger.info(
        f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
        extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
    )

    return task

async def publish_chunk(
    task_id: str,
    chunk: StreamBaseResponse,
) -> str:
    """Publish a chunk to Redis Stream.

    All delivery is via Redis Streams - no in-memory state.

    Args:
        task_id: Task ID to publish to
        chunk: The stream response chunk to publish

    Returns:
        The Redis Stream message ID
    """
    import time

    start_time = time.perf_counter()
    chunk_type = type(chunk).__name__
    chunk_json = chunk.model_dump_json()
    message_id = "0-0"

    # Build log metadata
    log_meta = {
        "component": "StreamRegistry",
        "task_id": task_id,
        "chunk_type": chunk_type,
    }

    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)

        # Write to Redis Stream for persistence and real-time delivery
        xadd_start = time.perf_counter()
        raw_id = await redis.xadd(
            stream_key,
            {"data": chunk_json},
            maxlen=config.stream_max_length,
        )
        xadd_time = (time.perf_counter() - xadd_start) * 1000
        message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()

        # Set TTL on stream to match task metadata TTL
        await redis.expire(stream_key, config.stream_ttl)

        total_time = (time.perf_counter() - start_time) * 1000
        # Only log timing for significant chunks or slow operations
        if (
            chunk_type
            in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
            or total_time > 50
        ):
            logger.info(
                f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
                extra={
                    "json_fields": {
                        **log_meta,
                        "total_time_ms": total_time,
                        "xadd_time_ms": xadd_time,
                        "message_id": message_id,
                    }
                },
            )
    except Exception as e:
        elapsed = (time.perf_counter() - start_time) * 1000
        logger.error(
            f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
            extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
            exc_info=True,
        )

    return message_id

async def subscribe_to_task(
    task_id: str,
    user_id: str | None,
    last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
    """Subscribe to a task's stream with replay of missed messages.

    This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.

    Args:
        task_id: Task ID to subscribe to
        user_id: User ID for ownership validation
        last_message_id: Last Redis Stream message ID received ("0-0" for full replay)

    Returns:
        An asyncio Queue that will receive stream chunks, or None if task not found
        or user doesn't have access
    """
    import time

    start_time = time.perf_counter()

    # Build log metadata
    log_meta = {"component": "StreamRegistry", "task_id": task_id}
    if user_id:
        log_meta["user_id"] = user_id

    logger.info(
        f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
        extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
    )

    redis_start = time.perf_counter()
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]
    hgetall_time = (time.perf_counter() - redis_start) * 1000
    logger.info(
        f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
        extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
    )

    if not meta:
        elapsed = (time.perf_counter() - start_time) * 1000
        logger.info(
            f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
            extra={
                "json_fields": {
                    **log_meta,
                    "elapsed_ms": elapsed,
                    "reason": "task_not_found",
                }
            },
        )
        return None

    # Note: Redis client uses decode_responses=True, so keys are strings
    task_status = meta.get("status", "")
    task_user_id = meta.get("user_id", "") or None
    log_meta["session_id"] = meta.get("session_id", "")

    # Validate ownership - if task has an owner, requester must match
    if task_user_id:
        if user_id != task_user_id:
            logger.warning(
                f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
                extra={
                    "json_fields": {
                        **log_meta,
                        "task_owner": task_user_id,
                        "reason": "access_denied",
                    }
                },
            )
            return None

    subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
    stream_key = _get_task_stream_key(task_id)

    # Step 1: Replay messages from Redis Stream
    xread_start = time.perf_counter()
    messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
    xread_time = (time.perf_counter() - xread_start) * 1000
    logger.info(
        f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
        extra={
            "json_fields": {
                **log_meta,
                "duration_ms": xread_time,
                "task_status": task_status,
            }
        },
    )

    replayed_count = 0
    replay_last_id = last_message_id
    if messages:
        for _stream_name, stream_messages in messages:
            for msg_id, msg_data in stream_messages:
                replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
                # Note: Redis client uses decode_responses=True, so keys are strings
                if "data" in msg_data:
                    try:
                        chunk_data = orjson.loads(msg_data["data"])
                        chunk = _reconstruct_chunk(chunk_data)
                        if chunk:
                            await subscriber_queue.put(chunk)
                            replayed_count += 1
                    except Exception as e:
                        logger.warning(f"Failed to replay message: {e}")

    logger.info(
        f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
        extra={
            "json_fields": {
                **log_meta,
                "n_messages_replayed": replayed_count,
                "replay_last_id": replay_last_id,
            }
        },
    )

    # Step 2: If task is still running, start stream listener for live updates
    if task_status == "running":
        logger.info(
            "[TIMING] Task still running, starting _stream_listener",
            extra={"json_fields": {**log_meta, "task_status": task_status}},
        )
        listener_task = asyncio.create_task(
            _stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
        )
        # Track listener task for cleanup on unsubscribe
        _listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
    else:
        # Task is completed/failed - add finish marker
        logger.info(
            f"[TIMING] Task already {task_status}, adding StreamFinish",
            extra={"json_fields": {**log_meta, "task_status": task_status}},
        )
        await subscriber_queue.put(StreamFinish())

    total_time = (time.perf_counter() - start_time) * 1000
    logger.info(
        f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
        f"n_messages_replayed={replayed_count}",
        extra={
            "json_fields": {
                **log_meta,
                "total_time_ms": total_time,
                "n_messages_replayed": replayed_count,
            }
        },
    )
    return subscriber_queue

|
async def _stream_listener(
|
||||||
|
task_id: str,
|
||||||
|
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||||
|
last_replayed_id: str,
|
||||||
|
log_meta: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||||
|
|
||||||
|
This approach avoids the duplicate message issue that can occur with pub/sub
|
||||||
|
when messages are published during the gap between replay and subscription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task ID to listen for
|
||||||
|
subscriber_queue: Queue to deliver messages to
|
||||||
|
last_replayed_id: Last message ID from replay (continue from here)
|
||||||
|
log_meta: Structured logging metadata
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Use provided log_meta or build minimal one
|
||||||
|
if log_meta is None:
|
||||||
|
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
queue_id = id(subscriber_queue)
|
||||||
|
# Track the last successfully delivered message ID for recovery hints
|
||||||
|
last_delivered_id = last_replayed_id
|
||||||
|
messages_delivered = 0
|
||||||
|
first_message_time = None
|
||||||
|
xread_count = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
current_id = last_replayed_id
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Block for up to 30 seconds waiting for new messages
|
||||||
|
# This allows periodic checking if task is still running
|
||||||
|
xread_start = time.perf_counter()
|
||||||
|
xread_count += 1
|
||||||
|
messages = await redis.xread(
|
||||||
|
{stream_key: current_id}, block=30000, count=100
|
||||||
|
)
|
||||||
|
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||||
|
|
||||||
|
if messages:
|
||||||
|
msg_count = sum(len(msgs) for _, msgs in messages)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
"n_messages": msg_count,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif xread_time > 1000:
|
||||||
|
# Only log timeouts (30s blocking)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
"reason": "timeout",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
# Timeout - check if task is still running
|
||||||
|
meta_key = _get_task_meta_key(task_id)
|
||||||
|
status = await redis.hget(meta_key, "status") # type: ignore[misc]
|
||||||
|
if status and status != "running":
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
subscriber_queue.put(StreamFinish()),
|
||||||
|
timeout=QUEUE_PUT_TIMEOUT,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
f"Timeout delivering finish event for task {task_id}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
for _stream_name, stream_messages in messages:
|
||||||
|
for msg_id, msg_data in stream_messages:
|
||||||
|
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||||
|
|
||||||
|
if "data" not in msg_data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk_data = orjson.loads(msg_data["data"])
|
||||||
|
chunk = _reconstruct_chunk(chunk_data)
|
||||||
|
if chunk:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
subscriber_queue.put(chunk),
|
||||||
|
timeout=QUEUE_PUT_TIMEOUT,
|
||||||
|
)
|
||||||
|
# Update last delivered ID on successful delivery
|
||||||
|
last_delivered_id = current_id
|
||||||
|
messages_delivered += 1
|
||||||
|
if first_message_time is None:
|
||||||
|
first_message_time = time.perf_counter()
|
||||||
|
elapsed = (first_message_time - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"chunk_type": type(chunk).__name__,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"timeout_s": QUEUE_PUT_TIMEOUT,
|
||||||
|
"reason": "queue_full",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Send overflow error with recovery info
|
||||||
|
try:
|
||||||
|
overflow_error = StreamError(
|
||||||
|
errorText="Message delivery timeout - some messages may have been missed",
|
||||||
|
code="QUEUE_OVERFLOW",
|
||||||
|
details={
|
||||||
|
"last_delivered_id": last_delivered_id,
|
||||||
|
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
subscriber_queue.put_nowait(overflow_error)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
# Queue is completely stuck, nothing more we can do
|
||||||
|
logger.error(
|
||||||
|
f"Cannot deliver overflow error for task {task_id}, "
|
||||||
|
"queue completely blocked"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop listening on finish
|
||||||
|
if isinstance(chunk, StreamFinish):
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error processing stream message: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "error": str(e)}},
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
"reason": "cancelled",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
raise # Re-raise to propagate cancellation
|
||||||
|
except Exception as e:
|
||||||
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.error(
|
||||||
|
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||||
|
)
|
||||||
|
# On error, send finish to unblock subscriber
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
subscriber_queue.put(StreamFinish()),
|
||||||
|
timeout=QUEUE_PUT_TIMEOUT,
|
||||||
|
)
|
||||||
|
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||||
|
logger.warning(
|
||||||
|
"Could not deliver finish event after error",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# Clean up listener task mapping on exit
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
|
||||||
|
f"delivered={messages_delivered}, xread_count={xread_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
_listener_tasks.pop(queue_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def mark_task_completed(
    task_id: str,
    status: Literal["completed", "failed"] = "completed",
) -> bool:
    """Mark a task as completed and publish finish event.

    This is idempotent - calling multiple times with the same task_id is safe.
    Uses atomic compare-and-swap via Lua script to prevent race conditions.
    Status is updated first (source of truth), then finish event is published (best-effort).

    Args:
        task_id: Task ID to mark as completed
        status: Final status ("completed" or "failed")

    Returns:
        True if task was newly marked completed, False if already completed/failed
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)

    # Atomic compare-and-swap: only update if status is "running"
    # This prevents race conditions when multiple callers try to complete simultaneously
    result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status)  # type: ignore[misc]

    if result == 0:
        logger.debug(f"Task {task_id} already completed/failed, skipping")
        return False

    # THEN publish finish event (best-effort - listeners can detect via status polling)
    try:
        await publish_chunk(task_id, StreamFinish())
    except Exception as e:
        logger.error(
            f"Failed to publish finish event for task {task_id}: {e}. "
            "Listeners will detect completion via status polling."
        )

    # Clean up local task reference if exists
    _local_tasks.pop(task_id, None)
    return True

async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
    """Find a task by its operation ID.

    Used by webhook callbacks to locate the task to update.

    Args:
        operation_id: Operation ID to search for

    Returns:
        ActiveTask if found, None otherwise
    """
    redis = await get_redis_async()
    op_key = _get_operation_mapping_key(operation_id)
    task_id = await redis.get(op_key)

    if not task_id:
        return None

    task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
    return await get_task(task_id_str)


async def get_task(task_id: str) -> ActiveTask | None:
    """Get a task by its ID from Redis.

    Args:
        task_id: Task ID to look up

    Returns:
        ActiveTask if found, None otherwise
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        return None

    # Note: Redis client uses decode_responses=True, so keys/values are strings
    return ActiveTask(
        task_id=meta.get("task_id", ""),
        session_id=meta.get("session_id", ""),
        user_id=meta.get("user_id", "") or None,
        tool_call_id=meta.get("tool_call_id", ""),
        tool_name=meta.get("tool_name", ""),
        operation_id=meta.get("operation_id", ""),
        status=meta.get("status", "running"),  # type: ignore[arg-type]
    )


async def get_task_with_expiry_info(
    task_id: str,
) -> tuple[ActiveTask | None, str | None]:
    """Get a task by its ID with expiration detection.

    Returns (task, error_code) where error_code is:
    - None if task found
    - "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
    - "TASK_NOT_FOUND" if neither exists

    Args:
        task_id: Task ID to look up

    Returns:
        Tuple of (ActiveTask or None, error_code or None)
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    stream_key = _get_task_stream_key(task_id)

    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        # Check if stream still has data (metadata expired but stream hasn't)
        stream_len = await redis.xlen(stream_key)
        if stream_len > 0:
            return None, "TASK_EXPIRED"
        return None, "TASK_NOT_FOUND"

    # Note: Redis client uses decode_responses=True, so keys/values are strings
    return (
        ActiveTask(
            task_id=meta.get("task_id", ""),
            session_id=meta.get("session_id", ""),
            user_id=meta.get("user_id", "") or None,
            tool_call_id=meta.get("tool_call_id", ""),
            tool_name=meta.get("tool_name", ""),
            operation_id=meta.get("operation_id", ""),
            status=meta.get("status", "running"),  # type: ignore[arg-type]
        ),
        None,
    )

||||||
|
async def get_active_task_for_session(
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> tuple[ActiveTask | None, str]:
|
||||||
|
"""Get the active (running) task for a session, if any.
|
||||||
|
|
||||||
|
Scans Redis for tasks matching the session_id with status="running".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Session ID to look up
|
||||||
|
user_id: User ID for ownership validation (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
|
||||||
|
"""
|
||||||
|
|
||||||
|
redis = await get_redis_async()
|
||||||
|
|
||||||
|
# Scan Redis for task metadata keys
|
||||||
|
cursor = 0
|
||||||
|
tasks_checked = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(
|
||||||
|
cursor, match=f"{config.task_meta_prefix}*", count=100
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
tasks_checked += 1
|
||||||
|
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
|
||||||
|
if not meta:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||||
|
task_session_id = meta.get("session_id", "")
|
||||||
|
task_status = meta.get("status", "")
|
||||||
|
task_user_id = meta.get("user_id", "") or None
|
||||||
|
task_id = meta.get("task_id", "")
|
||||||
|
|
||||||
|
if task_session_id == session_id and task_status == "running":
|
||||||
|
# Validate ownership - if task has an owner, requester must match
|
||||||
|
if task_user_id and user_id != task_user_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get the last message ID from Redis Stream
|
||||||
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
last_id = "0-0"
|
||||||
|
try:
|
||||||
|
messages = await redis.xrevrange(stream_key, count=1)
|
||||||
|
if messages:
|
||||||
|
msg_id = messages[0][0]
|
||||||
|
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get last message ID: {e}")
|
||||||
|
|
||||||
|
return (
|
||||||
|
ActiveTask(
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=task_session_id,
|
||||||
|
user_id=task_user_id,
|
||||||
|
tool_call_id=meta.get("tool_call_id", ""),
|
||||||
|
tool_name=meta.get("tool_name", ""),
|
||||||
|
operation_id=meta.get("operation_id", ""),
|
||||||
|
status="running",
|
||||||
|
),
|
||||||
|
last_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
return None, "0-0"
|
||||||
|
|
||||||
|
|
||||||
|
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
    """Reconstruct a StreamBaseResponse from JSON data.

    Args:
        chunk_data: Parsed JSON data from Redis

    Returns:
        Reconstructed response object, or None if unknown type
    """
    from .response_model import (
        ResponseType,
        StreamError,
        StreamFinish,
        StreamFinishStep,
        StreamHeartbeat,
        StreamStart,
        StreamStartStep,
        StreamTextDelta,
        StreamTextEnd,
        StreamTextStart,
        StreamToolInputAvailable,
        StreamToolInputStart,
        StreamToolOutputAvailable,
        StreamUsage,
    )

    # Map response types to their corresponding classes
    type_to_class: dict[str, type[StreamBaseResponse]] = {
        ResponseType.START.value: StreamStart,
        ResponseType.FINISH.value: StreamFinish,
        ResponseType.START_STEP.value: StreamStartStep,
        ResponseType.FINISH_STEP.value: StreamFinishStep,
        ResponseType.TEXT_START.value: StreamTextStart,
        ResponseType.TEXT_DELTA.value: StreamTextDelta,
        ResponseType.TEXT_END.value: StreamTextEnd,
        ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
        ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
        ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
        ResponseType.ERROR.value: StreamError,
        ResponseType.USAGE.value: StreamUsage,
        ResponseType.HEARTBEAT.value: StreamHeartbeat,
    }

    chunk_type = chunk_data.get("type")
    chunk_class = type_to_class.get(chunk_type)  # type: ignore[arg-type]

    if chunk_class is None:
        logger.warning(f"Unknown chunk type: {chunk_type}")
        return None

    try:
        return chunk_class(**chunk_data)
    except Exception as e:
        logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
        return None

async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
    """Track the asyncio.Task for a task (local reference only).

    This is just for cleanup purposes - the task state is in Redis.

    Args:
        task_id: Task ID
        asyncio_task: The asyncio Task to track
    """
    _local_tasks[task_id] = asyncio_task


async def unsubscribe_from_task(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
    """Clean up when a subscriber disconnects.

    Cancels the XREAD-based listener task associated with this subscriber queue
    to prevent resource leaks.

    Args:
        task_id: Task ID
        subscriber_queue: The subscriber's queue used to look up the listener task
    """
    queue_id = id(subscriber_queue)
    listener_entry = _listener_tasks.pop(queue_id, None)

    if listener_entry is None:
        logger.debug(
            f"No listener task found for task {task_id} queue {queue_id} "
            "(may have already completed)"
        )
        return

    stored_task_id, listener_task = listener_entry

    if stored_task_id != task_id:
        logger.warning(
            f"Task ID mismatch in unsubscribe: expected {task_id}, "
            f"found {stored_task_id}"
        )

    if listener_task.done():
        logger.debug(f"Listener task for task {task_id} already completed")
        return

    # Cancel the listener task
    listener_task.cancel()

    try:
        # Wait for the task to be cancelled with a timeout
        await asyncio.wait_for(listener_task, timeout=5.0)
    except asyncio.CancelledError:
        # Expected - the task was successfully cancelled
        pass
    except asyncio.TimeoutError:
        logger.warning(
            f"Timeout waiting for listener task cancellation for task {task_id}"
        )
    except Exception as e:
        logger.error(f"Error during listener task cancellation for task {task_id}: {e}")

    logger.debug(f"Successfully unsubscribed from task {task_id}")
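# --- Illustrative sketch (not part of the diff) ---
# Producer-side lifecycle of the registry above: create the task, publish
# chunks as work progresses, then mark it completed so subscribers receive a
# terminal StreamFinish. The tool/operation identifiers here are invented.
import uuid

async def run_streaming_task(session_id: str, user_id: str | None) -> str:
    task_id = str(uuid.uuid4())
    await create_task(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        tool_call_id=str(uuid.uuid4()),
        tool_name="example_tool",
        operation_id=str(uuid.uuid4()),
    )
    try:
        await publish_chunk(task_id, StreamStart(messageId=str(uuid.uuid4())))
        # ... publish StreamTextDelta / tool chunks here as work progresses ...
        await mark_task_completed(task_id, status="completed")
    except Exception:
        await mark_task_completed(task_id, status="failed")
        raise
    return task_id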
@@ -10,6 +10,7 @@ from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
@@ -34,6 +35,7 @@ logger = logging.getLogger(__name__)
TOOL_REGISTRY: dict[str, BaseTool] = {
    "add_understanding": AddUnderstandingTool(),
    "create_agent": CreateAgentTool(),
    "customize_agent": CustomizeAgentTool(),
    "edit_agent": EditAgentTool(),
    "find_agent": FindAgentTool(),
    "find_block": FindBlockTool(),
@@ -8,6 +8,7 @@ from .core import (
    DecompositionStep,
    LibraryAgentSummary,
    MarketplaceAgentSummary,
    customize_template,
    decompose_goal,
    enrich_library_agents_from_steps,
    extract_search_terms_from_steps,
@@ -19,6 +20,7 @@ from .core import (
    get_library_agent_by_graph_id,
    get_library_agent_by_id,
    get_library_agents_for_generation,
    graph_to_json,
    json_to_graph,
    save_agent_to_library,
    search_marketplace_agents_for_generation,
@@ -36,6 +38,7 @@ __all__ = [
    "LibraryAgentSummary",
    "MarketplaceAgentSummary",
    "check_external_service_health",
    "customize_template",
    "decompose_goal",
    "enrich_library_agents_from_steps",
    "extract_search_terms_from_steps",
@@ -48,6 +51,7 @@ __all__ = [
    "get_library_agent_by_id",
    "get_library_agents_for_generation",
    "get_user_message_for_error",
    "graph_to_json",
    "is_external_service_configured",
    "json_to_graph",
    "save_agent_to_library",
@@ -7,18 +7,11 @@ from typing import Any, NotRequired, TypedDict
|
|||||||
|
|
||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data.graph import (
|
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
||||||
Graph,
|
|
||||||
Link,
|
|
||||||
Node,
|
|
||||||
create_graph,
|
|
||||||
get_graph,
|
|
||||||
get_graph_all_versions,
|
|
||||||
get_store_listed_graphs,
|
|
||||||
)
|
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .service import (
|
from .service import (
|
||||||
|
customize_template_external,
|
||||||
decompose_goal_external,
|
decompose_goal_external,
|
||||||
generate_agent_external,
|
generate_agent_external,
|
||||||
generate_agent_patch_external,
|
generate_agent_patch_external,
|
||||||
@@ -27,8 +20,6 @@ from .service import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionSummary(TypedDict):
|
class ExecutionSummary(TypedDict):
|
||||||
"""Summary of a single execution for quality assessment."""
|
"""Summary of a single execution for quality assessment."""
|
||||||
@@ -549,15 +540,21 @@ async def decompose_goal(
|
|||||||
async def generate_agent(
|
async def generate_agent(
|
||||||
instructions: DecompositionResult | dict[str, Any],
|
instructions: DecompositionResult | dict[str, Any],
|
||||||
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Generate agent JSON from instructions.
|
"""Generate agent JSON from instructions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instructions: Structured instructions from decompose_goal
|
instructions: Structured instructions from decompose_goal
|
||||||
library_agents: User's library agents available for sub-agent composition
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
operation_id: Operation ID for async processing (enables Redis Streams
|
||||||
|
completion notification)
|
||||||
|
task_id: Task ID for async processing (enables Redis Streams persistence
|
||||||
|
and SSE delivery)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
@@ -565,8 +562,13 @@ async def generate_agent(
|
|||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for generate_agent")
|
logger.info("Calling external Agent Generator service for generate_agent")
|
||||||
result = await generate_agent_external(
|
result = await generate_agent_external(
|
||||||
dict(instructions), _to_dict_list(library_agents)
|
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Don't modify async response
|
||||||
|
if result and result.get("status") == "accepted":
|
||||||
|
return result
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
if isinstance(result, dict) and result.get("type") == "error":
|
if isinstance(result, dict) and result.get("type") == "error":
|
||||||
return result
|
return result
|
||||||
@@ -657,45 +659,6 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _reassign_node_ids(graph: Graph) -> None:
|
|
||||||
"""Reassign all node and link IDs to new UUIDs.
|
|
||||||
|
|
||||||
This is needed when creating a new version to avoid unique constraint violations.
|
|
||||||
"""
|
|
||||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
|
||||||
|
|
||||||
for node in graph.nodes:
|
|
||||||
node.id = id_map[node.id]
|
|
||||||
|
|
||||||
for link in graph.links:
|
|
||||||
link.id = str(uuid.uuid4())
|
|
||||||
if link.source_id in id_map:
|
|
||||||
link.source_id = id_map[link.source_id]
|
|
||||||
if link.sink_id in id_map:
|
|
||||||
link.sink_id = id_map[link.sink_id]
|
|
||||||
|
|
||||||
|
|
||||||
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
|
|
||||||
"""Populate user_id in AgentExecutorBlock nodes.
|
|
||||||
|
|
||||||
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
|
|
||||||
This function fills in the actual user_id so sub-agents run with correct permissions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_json: Agent JSON dict (modified in place)
|
|
||||||
user_id: User ID to set
|
|
||||||
"""
|
|
||||||
for node in agent_json.get("nodes", []):
|
|
||||||
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
|
|
||||||
input_default = node.get("input_default") or {}
|
|
||||||
if not input_default.get("user_id"):
|
|
||||||
input_default["user_id"] = user_id
|
|
||||||
node["input_default"] = input_default
|
|
||||||
logger.debug(
|
|
||||||
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def save_agent_to_library(
|
async def save_agent_to_library(
|
||||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||||
) -> tuple[Graph, Any]:
|
) -> tuple[Graph, Any]:
|
||||||
@@ -709,63 +672,21 @@ async def save_agent_to_library(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (created Graph, LibraryAgent)
|
Tuple of (created Graph, LibraryAgent)
|
||||||
"""
|
"""
|
||||||
# Populate user_id in AgentExecutorBlock nodes before conversion
|
|
||||||
_populate_agent_executor_user_ids(agent_json, user_id)
|
|
||||||
|
|
||||||
graph = json_to_graph(agent_json)
|
graph = json_to_graph(agent_json)
|
||||||
|
|
||||||
if is_update:
|
if is_update:
|
||||||
if graph.id:
|
return await library_db.update_graph_in_library(graph, user_id)
|
||||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
return await library_db.create_graph_in_library(graph, user_id)
|
||||||
if existing_versions:
|
|
||||||
latest_version = max(v.version for v in existing_versions)
|
|
||||||
graph.version = latest_version + 1
|
|
||||||
_reassign_node_ids(graph)
|
|
||||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
|
||||||
else:
|
|
||||||
graph.id = str(uuid.uuid4())
|
|
||||||
graph.version = 1
|
|
||||||
_reassign_node_ids(graph)
|
|
||||||
logger.info(f"Creating new agent with ID {graph.id}")
|
|
||||||
|
|
||||||
created_graph = await create_graph(graph, user_id)
|
|
||||||
|
|
||||||
library_agents = await library_db.create_library_agent(
|
|
||||||
graph=created_graph,
|
|
||||||
user_id=user_id,
|
|
||||||
sensitive_action_safe_mode=True,
|
|
||||||
create_library_agents_for_sub_graphs=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return created_graph, library_agents[0]
|
|
||||||
|
|
||||||
|
|
||||||
async def get_agent_as_json(
|
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||||
agent_id: str, user_id: str | None
|
"""Convert a Graph object to JSON format for the agent generator.
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Fetch an agent and convert to JSON format for editing.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent_id: Graph ID or library agent ID
|
graph: Graph object to convert
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent as JSON dict or None if not found
|
Agent as JSON dict
|
||||||
"""
|
"""
|
||||||
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
|
||||||
|
|
||||||
if not graph and user_id:
|
|
||||||
try:
|
|
||||||
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
|
||||||
graph = await get_graph(
|
|
||||||
library_agent.graph_id, version=None, user_id=user_id
|
|
||||||
)
|
|
||||||
except NotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not graph:
|
|
||||||
return None
|
|
||||||
|
|
||||||
nodes = []
|
nodes = []
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
nodes.append(
|
nodes.append(
|
||||||
@@ -802,10 +723,41 @@ async def get_agent_as_json(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_agent_as_json(
|
||||||
|
agent_id: str, user_id: str | None
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Fetch an agent and convert to JSON format for editing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Graph ID or library agent ID
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent as JSON dict or None if not found
|
||||||
|
"""
|
||||||
|
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||||
|
|
||||||
|
if not graph and user_id:
|
||||||
|
try:
|
||||||
|
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
|
graph = await get_graph(
|
||||||
|
library_agent.graph_id, version=None, user_id=user_id
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not graph:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return graph_to_json(graph)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_patch(
|
async def generate_agent_patch(
|
||||||
update_request: str,
|
update_request: str,
|
||||||
current_agent: dict[str, Any],
|
current_agent: dict[str, Any],
|
||||||
library_agents: list[AgentSummary] | None = None,
|
library_agents: list[AgentSummary] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Update an existing agent using natural language.
|
"""Update an existing agent using natural language.
|
||||||
|
|
||||||
@@ -818,10 +770,12 @@ async def generate_agent_patch(
|
|||||||
update_request: Natural language description of changes
|
update_request: Natural language description of changes
|
||||||
current_agent: Current agent JSON
|
current_agent: Current agent JSON
|
||||||
library_agents: User's library agents available for sub-agent composition
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||||
|
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||||
error dict {"type": "error", ...}, or None on unexpected error
|
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
@@ -829,5 +783,43 @@ async def generate_agent_patch(
|
|||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||||
return await generate_agent_patch_external(
|
return await generate_agent_patch_external(
|
||||||
update_request, current_agent, _to_dict_list(library_agents)
|
update_request,
|
||||||
|
current_agent,
|
||||||
|
_to_dict_list(library_agents),
|
||||||
|
operation_id,
|
||||||
|
task_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def customize_template(
|
||||||
|
template_agent: dict[str, Any],
|
||||||
|
modification_request: str,
|
||||||
|
context: str = "",
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Customize a template/marketplace agent using natural language.
|
||||||
|
|
||||||
|
This is used when users want to modify a template or marketplace agent
|
||||||
|
to fit their specific needs before adding it to their library.
|
||||||
|
|
||||||
|
The external Agent Generator service handles:
|
||||||
|
- Understanding the modification request
|
||||||
|
- Applying changes to the template
|
||||||
|
- Fixing and validating the result
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_agent: The template agent JSON to customize
|
||||||
|
modification_request: Natural language description of customizations
|
||||||
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||||
|
error dict {"type": "error", ...}, or None on unexpected error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
|
"""
|
||||||
|
_check_service_configured()
|
||||||
|
logger.info("Calling external Agent Generator service for customize_template")
|
||||||
|
return await customize_template_external(
|
||||||
|
template_agent, modification_request, context
|
||||||
)
|
)
|
||||||
|
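# --- Illustrative sketch (not part of the diff) ---
# How a caller might branch on the async "accepted" response described in the
# generate_agent docstring above. Only the {"status": "accepted"} shape and the
# operation_id/task_id parameters come from the source; the handling is invented.
async def generate_or_defer(instructions, library_agents, operation_id, task_id):
    result = await generate_agent(
        instructions,
        library_agents,
        operation_id=operation_id,
        task_id=task_id,
    )
    if result is None or (isinstance(result, dict) and result.get("type") == "error"):
        return result  # hard failure, surface to the caller
    if result.get("status") == "accepted":
        # Generation continues out-of-band; completion is delivered through the
        # operation_id -> task_id mapping and the task's Redis Stream.
        return {"pending": True, "task_id": result.get("task_id")}
    return result  # synchronous path: agent JSON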
|||||||
@@ -212,24 +212,45 @@ async def decompose_goal_external(
|
|||||||
async def generate_agent_external(
|
async def generate_agent_external(
|
||||||
instructions: dict[str, Any],
|
instructions: dict[str, Any],
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to generate an agent from instructions.
|
"""Call the external service to generate an agent from instructions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instructions: Structured instructions from decompose_goal
|
instructions: Structured instructions from decompose_goal
|
||||||
library_agents: User's library agents available for sub-agent composition
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||||
|
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict on success, or error dict {"type": "error", ...} on error
|
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
|
# Build request payload
|
||||||
payload: dict[str, Any] = {"instructions": instructions}
|
payload: dict[str, Any] = {"instructions": instructions}
|
||||||
if library_agents:
|
if library_agents:
|
||||||
payload["library_agents"] = library_agents
|
payload["library_agents"] = library_agents
|
||||||
|
if operation_id and task_id:
|
||||||
|
payload["operation_id"] = operation_id
|
||||||
|
payload["task_id"] = task_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post("/api/generate-agent", json=payload)
|
response = await client.post("/api/generate-agent", json=payload)
|
||||||
|
|
||||||
|
# Handle 202 Accepted for async processing
|
||||||
|
if response.status_code == 202:
|
||||||
|
logger.info(
|
||||||
|
f"Agent Generator accepted async request "
|
||||||
|
f"(operation_id={operation_id}, task_id={task_id})"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "accepted",
|
||||||
|
"operation_id": operation_id,
|
||||||
|
"task_id": task_id,
|
||||||
|
}
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
@@ -261,6 +282,8 @@ async def generate_agent_patch_external(
     update_request: str,
     current_agent: dict[str, Any],
     library_agents: list[dict[str, Any]] | None = None,
+    operation_id: str | None = None,
+    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate a patch for an existing agent.
 
@@ -268,21 +291,40 @@ async def generate_agent_patch_external(
         update_request: Natural language description of changes
         current_agent: Current agent JSON
         library_agents: User's library agents available for sub-agent composition
+        operation_id: Operation ID for async processing (enables Redis Streams callback)
+        task_id: Task ID for async processing (enables Redis Streams callback)
 
     Returns:
-        Updated agent JSON, clarifying questions dict, or error dict on error
+        Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
     """
     client = _get_client()
 
+    # Build request payload
     payload: dict[str, Any] = {
         "update_request": update_request,
         "current_agent_json": current_agent,
     }
     if library_agents:
         payload["library_agents"] = library_agents
+    if operation_id and task_id:
+        payload["operation_id"] = operation_id
+        payload["task_id"] = task_id
 
     try:
         response = await client.post("/api/update-agent", json=payload)
+
+        # Handle 202 Accepted for async processing
+        if response.status_code == 202:
+            logger.info(
+                f"Agent Generator accepted async update request "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return {
+                "status": "accepted",
+                "operation_id": operation_id,
+                "task_id": task_id,
+            }
+
         response.raise_for_status()
         data = response.json()
 
@@ -326,6 +368,77 @@ async def generate_agent_patch_external(
         return _create_error_response(error_msg, "unexpected_error")
 
 
+async def customize_template_external(
+    template_agent: dict[str, Any],
+    modification_request: str,
+    context: str = "",
+) -> dict[str, Any] | None:
+    """Call the external service to customize a template/marketplace agent.
+
+    Args:
+        template_agent: The template agent JSON to customize
+        modification_request: Natural language description of customizations
+        context: Additional context (e.g., answers to previous questions)
+
+    Returns:
+        Customized agent JSON, clarifying questions dict, or error dict on error
+    """
+    client = _get_client()
+
+    request = modification_request
+    if context:
+        request = f"{modification_request}\n\nAdditional context from user:\n{context}"
+
+    payload: dict[str, Any] = {
+        "template_agent_json": template_agent,
+        "modification_request": request,
+    }
+
+    try:
+        response = await client.post("/api/template-modification", json=payload)
+        response.raise_for_status()
+        data = response.json()
+
+        if not data.get("success"):
+            error_msg = data.get("error", "Unknown error from Agent Generator")
+            error_type = data.get("error_type", "unknown")
+            logger.error(
+                f"Agent Generator template customization failed: {error_msg} "
+                f"(type: {error_type})"
+            )
+            return _create_error_response(error_msg, error_type)
+
+        # Check if it's clarifying questions
+        if data.get("type") == "clarifying_questions":
+            return {
+                "type": "clarifying_questions",
+                "questions": data.get("questions", []),
+            }
+
+        # Check if it's an error passed through
+        if data.get("type") == "error":
+            return _create_error_response(
+                data.get("error", "Unknown error"),
+                data.get("error_type", "unknown"),
+            )
+
+        # Otherwise return the customized agent JSON
+        return data.get("agent_json")
+
+    except httpx.HTTPStatusError as e:
+        error_type, error_msg = _classify_http_error(e)
+        logger.error(error_msg)
+        return _create_error_response(error_msg, error_type)
+    except httpx.RequestError as e:
+        error_type, error_msg = _classify_request_error(e)
+        logger.error(error_msg)
+        return _create_error_response(error_msg, error_type)
+    except Exception as e:
+        error_msg = f"Unexpected error calling Agent Generator: {e}"
+        logger.error(error_msg)
+        return _create_error_response(error_msg, "unexpected_error")
+
+
 async def get_blocks_external() -> list[dict[str, Any]] | None:
     """Get available blocks from the external service.
 
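For reference, this is the payload shape the new function posts to /api/template-modification, showing how the optional context is folded into the modification request; the values below are illustrative, not taken from the diff.

example_payload = {
    "template_agent_json": {"name": "Newsletter Writer", "nodes": [], "links": []},
    "modification_request": (
        "Send the newsletter to Slack instead of email"
        "\n\nAdditional context from user:\n"
        "Use the #general channel"
    ),
}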
@@ -206,9 +206,9 @@ async def search_agents(
             ]
         )
         no_results_msg = (
-            f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
+            f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
             if source == "marketplace"
-            else f"No agents matching '{query}' found in your library."
+            else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
         )
         return NoResultsResponse(
             message=no_results_msg, session_id=session_id, suggestions=suggestions
@@ -224,10 +224,10 @@ async def search_agents(
     message = (
         "Now you have found some options for the user to choose from. "
         "You can add a link to a recommended agent at: /marketplace/agent/agent_id "
-        "Please ask the user if they would like to use any of these agents."
+        "Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
        if source == "marketplace"
         else "Found agents in the user's library. You can provide a link to view an agent at: "
-        "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
+        "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
     )
 
     return AgentsFoundResponse(
@@ -18,6 +18,7 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
+    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -98,6 +99,10 @@ class CreateAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None
 
+        # Extract async processing params (passed by long-running tool handler)
+        operation_id = kwargs.get("_operation_id")
+        task_id = kwargs.get("_task_id")
+
         if not description:
             return ErrorResponse(
                 message="Please provide a description of what the agent should do.",
@@ -219,7 +224,12 @@ class CreateAgentTool(BaseTool):
             logger.warning(f"Failed to enrich library agents from steps: {e}")
 
         try:
-            agent_json = await generate_agent(decomposition_result, library_agents)
+            agent_json = await generate_agent(
+                decomposition_result,
+                library_agents,
+                operation_id=operation_id,
+                task_id=task_id,
+            )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -263,6 +273,19 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )
 
+        # Check if Agent Generator accepted for async processing
+        if agent_json.get("status") == "accepted":
+            logger.info(
+                f"Agent generation delegated to async processing "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return AsyncProcessingResponse(
+                message="Agent generation started. You'll be notified when it's complete.",
+                operation_id=operation_id,
+                task_id=task_id,
+                session_id=session_id,
+            )
+
         agent_name = agent_json.get("name", "Generated Agent")
         agent_description = agent_json.get("description", "")
         node_count = len(agent_json.get("nodes", []))
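The _operation_id/_task_id kwargs read above are injected by the long-running tool handler, which is outside this diff. A hedged sketch of how such a handler might thread them through; the wrapper name and call shape are assumptions, not the actual handler API.

import uuid
from typing import Any

async def run_long_running_tool(tool, user_id: str, session, **tool_args: Any):
    """Hypothetical wrapper: inject async-processing IDs before executing a tool."""
    operation_id = str(uuid.uuid4())
    task_id = str(uuid.uuid4())
    # CreateAgentTool reads these via kwargs.get("_operation_id") / kwargs.get("_task_id")
    return await tool._execute(
        user_id=user_id,
        session=session,
        _operation_id=operation_id,
        _task_id=task_id,
        **tool_args,
    )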
@@ -0,0 +1,337 @@
|
|||||||
|
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.store import db as store_db
|
||||||
|
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||||
|
|
||||||
|
from .agent_generator import (
|
||||||
|
AgentGeneratorNotConfiguredError,
|
||||||
|
customize_template,
|
||||||
|
get_user_message_for_error,
|
||||||
|
graph_to_json,
|
||||||
|
save_agent_to_library,
|
||||||
|
)
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import (
|
||||||
|
AgentPreviewResponse,
|
||||||
|
AgentSavedResponse,
|
||||||
|
ClarificationNeededResponse,
|
||||||
|
ClarifyingQuestion,
|
||||||
|
ErrorResponse,
|
||||||
|
ToolResponseBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomizeAgentTool(BaseTool):
|
||||||
|
"""Tool for customizing marketplace/template agents using natural language."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "customize_agent"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Customize a marketplace or template agent using natural language. "
|
||||||
|
"Takes an existing agent from the marketplace and modifies it based on "
|
||||||
|
"the user's requirements before adding to their library."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The marketplace agent ID in format 'creator/slug' "
|
||||||
|
"(e.g., 'autogpt/newsletter-writer'). "
|
||||||
|
"Get this from find_agent results."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"modifications": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Natural language description of how to customize the agent. "
|
||||||
|
"Be specific about what changes you want to make."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"context": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Additional context or answers to previous clarifying questions."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"save": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": (
|
||||||
|
"Whether to save the customized agent to the user's library. "
|
||||||
|
"Default is true. Set to false for preview only."
|
||||||
|
),
|
||||||
|
"default": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["agent_id", "modifications"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""Execute the customize_agent tool.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. Parse the agent ID to get creator/slug
|
||||||
|
2. Fetch the template agent from the marketplace
|
||||||
|
3. Call customize_template with the modification request
|
||||||
|
4. Preview or save based on the save parameter
|
||||||
|
"""
|
||||||
|
agent_id = kwargs.get("agent_id", "").strip()
|
||||||
|
modifications = kwargs.get("modifications", "").strip()
|
||||||
|
context = kwargs.get("context", "")
|
||||||
|
save = kwargs.get("save", True)
|
||||||
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
|
if not agent_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||||
|
error="missing_agent_id",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not modifications:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please describe how you want to customize this agent.",
|
||||||
|
error="missing_modifications",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse agent_id in format "creator/slug"
|
||||||
|
parts = [p.strip() for p in agent_id.split("/")]
|
||||||
|
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Invalid agent ID format: '{agent_id}'. "
|
||||||
|
"Expected format is 'creator/agent-name' "
|
||||||
|
"(e.g., 'autogpt/newsletter-writer')."
|
||||||
|
),
|
||||||
|
error="invalid_agent_id_format",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
creator_username, agent_slug = parts
|
||||||
|
|
||||||
|
# Fetch the marketplace agent details
|
||||||
|
try:
|
||||||
|
agent_details = await store_db.get_store_agent_details(
|
||||||
|
username=creator_username, agent_name=agent_slug
|
||||||
|
)
|
||||||
|
except AgentNotFoundError:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Could not find marketplace agent '{agent_id}'. "
|
||||||
|
"Please check the agent ID and try again."
|
||||||
|
),
|
||||||
|
error="agent_not_found",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to fetch the marketplace agent. Please try again.",
|
||||||
|
error="fetch_error",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not agent_details.store_listing_version_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"The agent '{agent_id}' does not have an available version. "
|
||||||
|
"Please try a different agent."
|
||||||
|
),
|
||||||
|
error="no_version_available",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the full agent graph
|
||||||
|
try:
|
||||||
|
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||||
|
template_agent = graph_to_json(graph)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to fetch the agent configuration. Please try again.",
|
||||||
|
error="graph_fetch_error",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call customize_template
|
||||||
|
try:
|
||||||
|
result = await customize_template(
|
||||||
|
template_agent=template_agent,
|
||||||
|
modification_request=modifications,
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
except AgentGeneratorNotConfiguredError:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
"Agent customization is not available. "
|
||||||
|
"The Agent Generator service is not configured."
|
||||||
|
),
|
||||||
|
error="service_not_configured",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
"Failed to customize the agent due to a service error. "
|
||||||
|
"Please try again."
|
||||||
|
),
|
||||||
|
error="customization_service_error",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
"Failed to customize the agent. "
|
||||||
|
"The agent generation service may be unavailable or timed out. "
|
||||||
|
"Please try again."
|
||||||
|
),
|
||||||
|
error="customization_failed",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle error response
|
||||||
|
if isinstance(result, dict) and result.get("type") == "error":
|
||||||
|
error_msg = result.get("error", "Unknown error")
|
||||||
|
error_type = result.get("error_type", "unknown")
|
||||||
|
user_message = get_user_message_for_error(
|
||||||
|
error_type,
|
||||||
|
operation="customize the agent",
|
||||||
|
llm_parse_message=(
|
||||||
|
"The AI had trouble customizing the agent. "
|
||||||
|
"Please try again or simplify your request."
|
||||||
|
),
|
||||||
|
validation_message=(
|
||||||
|
"The customized agent failed validation. "
|
||||||
|
"Please try rephrasing your request."
|
||||||
|
),
|
||||||
|
error_details=error_msg,
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=user_message,
|
||||||
|
error=f"customization_failed:{error_type}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle clarifying questions
|
||||||
|
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
||||||
|
questions = result.get("questions") or []
|
||||||
|
if not isinstance(questions, list):
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected clarifying questions format: {type(questions)}"
|
||||||
|
)
|
||||||
|
questions = []
|
||||||
|
return ClarificationNeededResponse(
|
||||||
|
message=(
|
||||||
|
"I need some more information to customize this agent. "
|
||||||
|
"Please answer the following questions:"
|
||||||
|
),
|
||||||
|
questions=[
|
||||||
|
ClarifyingQuestion(
|
||||||
|
question=q.get("question", ""),
|
||||||
|
keyword=q.get("keyword", ""),
|
||||||
|
example=q.get("example"),
|
||||||
|
)
|
||||||
|
for q in questions
|
||||||
|
if isinstance(q, dict)
|
||||||
|
],
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Result should be the customized agent JSON
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to customize the agent due to an unexpected response.",
|
||||||
|
error="unexpected_response_type",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
customized_agent = result
|
||||||
|
|
||||||
|
agent_name = customized_agent.get(
|
||||||
|
"name", f"Customized {agent_details.agent_name}"
|
||||||
|
)
|
||||||
|
agent_description = customized_agent.get("description", "")
|
||||||
|
nodes = customized_agent.get("nodes")
|
||||||
|
links = customized_agent.get("links")
|
||||||
|
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||||
|
link_count = len(links) if isinstance(links, list) else 0
|
||||||
|
|
||||||
|
if not save:
|
||||||
|
return AgentPreviewResponse(
|
||||||
|
message=(
|
||||||
|
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||||
|
f"The customized agent has {node_count} blocks. "
|
||||||
|
f"Review it and call customize_agent with save=true to save it."
|
||||||
|
),
|
||||||
|
agent_json=customized_agent,
|
||||||
|
agent_name=agent_name,
|
||||||
|
description=agent_description,
|
||||||
|
node_count=node_count,
|
||||||
|
link_count=link_count,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="You must be logged in to save agents.",
|
||||||
|
error="auth_required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to user's library
|
||||||
|
try:
|
||||||
|
created_graph, library_agent = await save_agent_to_library(
|
||||||
|
customized_agent, user_id, is_update=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentSavedResponse(
|
||||||
|
message=(
|
||||||
|
f"Customized agent '{created_graph.name}' "
|
||||||
|
f"(based on '{agent_details.agent_name}') "
|
||||||
|
f"has been saved to your library!"
|
||||||
|
),
|
||||||
|
agent_id=created_graph.id,
|
||||||
|
agent_name=created_graph.name,
|
||||||
|
library_agent_id=library_agent.id,
|
||||||
|
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||||
|
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving customized agent: {e}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to save the customized agent. Please try again.",
|
||||||
|
error="save_failed",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
@@ -17,6 +17,7 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
+    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -104,6 +105,10 @@ class EditAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None
 
+        # Extract async processing params (passed by long-running tool handler)
+        operation_id = kwargs.get("_operation_id")
+        task_id = kwargs.get("_task_id")
+
         if not agent_id:
             return ErrorResponse(
                 message="Please provide the agent ID to edit.",
@@ -149,7 +154,11 @@ class EditAgentTool(BaseTool):
 
         try:
             result = await generate_agent_patch(
-                update_request, current_agent, library_agents
+                update_request,
+                current_agent,
+                library_agents,
+                operation_id=operation_id,
+                task_id=task_id,
             )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
@@ -169,6 +178,20 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )
 
+        # Check if Agent Generator accepted for async processing
+        if result.get("status") == "accepted":
+            logger.info(
+                f"Agent edit delegated to async processing "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return AsyncProcessingResponse(
+                message="Agent edit started. You'll be notified when it's complete.",
+                operation_id=operation_id,
+                task_id=task_id,
+                session_id=session_id,
+            )
+
+        # Check if the result is an error from the external service
         if isinstance(result, dict) and result.get("type") == "error":
             error_msg = result.get("error", "Unknown error")
             error_type = result.get("error_type", "unknown")
@@ -13,10 +13,33 @@ from backend.api.features.chat.tools.models import (
     NoResultsResponse,
 )
 from backend.api.features.store.hybrid_search import unified_hybrid_search
-from backend.data.block import get_block
+from backend.blocks import get_block
+from backend.blocks._base import BlockType
 
 logger = logging.getLogger(__name__)
 
+_TARGET_RESULTS = 10
+# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
+# 40 is 2x the number of blocks currently filtered out; query speed for 10 vs 40 results is nearly identical.
+_OVERFETCH_PAGE_SIZE = 40
+
+# Block types that only work within graphs and cannot run standalone in CoPilot.
+COPILOT_EXCLUDED_BLOCK_TYPES = {
+    BlockType.INPUT,  # Graph interface definition - data enters via chat, not graph inputs
+    BlockType.OUTPUT,  # Graph interface definition - data exits via chat, not graph outputs
+    BlockType.WEBHOOK,  # Wait for external events - would hang forever in CoPilot
+    BlockType.WEBHOOK_MANUAL,  # Same as WEBHOOK
+    BlockType.NOTE,  # Visual annotation only - no runtime behavior
+    BlockType.HUMAN_IN_THE_LOOP,  # Pauses for human approval - CoPilot IS human-in-the-loop
+    BlockType.AGENT,  # AgentExecutorBlock requires execution_context - use run_agent tool
+}
+
+# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
+COPILOT_EXCLUDED_BLOCK_IDS = {
+    # SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
+    "3b191d9f-356f-482d-8238-ba04b6d18381",
+}
+
 
 class FindBlockTool(BaseTool):
     """Tool for searching available blocks."""
@@ -88,7 +111,7 @@ class FindBlockTool(BaseTool):
             query=query,
             content_types=[ContentType.BLOCK],
             page=1,
-            page_size=10,
+            page_size=_OVERFETCH_PAGE_SIZE,
         )
 
         if not results:
@@ -108,18 +131,35 @@ class FindBlockTool(BaseTool):
             block = get_block(block_id)
 
             # Skip disabled blocks
-            if block and not block.disabled:
+            if not block or block.disabled:
+                continue
+
+            # Skip blocks excluded from CoPilot (graph-only blocks)
+            if (
+                block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
+                or block.id in COPILOT_EXCLUDED_BLOCK_IDS
+            ):
+                continue
+
             # Get input/output schemas
             input_schema = {}
             output_schema = {}
             try:
                 input_schema = block.input_schema.jsonschema()
-            except Exception:
-                pass
+            except Exception as e:
+                logger.debug(
+                    "Failed to generate input schema for block %s: %s",
+                    block_id,
+                    e,
+                )
             try:
                 output_schema = block.output_schema.jsonschema()
-            except Exception:
-                pass
+            except Exception as e:
+                logger.debug(
+                    "Failed to generate output schema for block %s: %s",
+                    block_id,
+                    e,
+                )
 
             # Get categories from block instance
             categories = []
@@ -163,6 +203,19 @@ class FindBlockTool(BaseTool):
                     )
                 )
 
+            if len(blocks) >= _TARGET_RESULTS:
+                break
+
+        if blocks and len(blocks) < _TARGET_RESULTS:
+            logger.debug(
+                "find_block returned %d/%d results for query '%s' "
+                "(filtered %d excluded/disabled blocks)",
+                len(blocks),
+                _TARGET_RESULTS,
+                query,
+                len(results) - len(blocks),
+            )
+
         if not blocks:
             return NoResultsResponse(
                 message=f"No blocks found for '{query}'",
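The over-fetch-then-filter pattern introduced above can be summarized on its own. A simplified, self-contained sketch under stated assumptions: plain dicts stand in for search hits and Block objects, whereas the real tool operates on hybrid-search results and registered blocks.

from typing import Any

TARGET_RESULTS = 10  # mirrors _TARGET_RESULTS
OVERFETCH_PAGE_SIZE = 40  # mirrors _OVERFETCH_PAGE_SIZE

def filter_copilot_blocks(
    candidates: list[dict[str, Any]],
    excluded_types: set[str],
    excluded_ids: set[str],
) -> list[dict[str, Any]]:
    """Keep at most TARGET_RESULTS candidates that CoPilot can actually run."""
    kept: list[dict[str, Any]] = []
    for block in candidates:  # candidates come from an over-fetched page (size 40)
        if block.get("disabled"):
            continue
        if block["type"] in excluded_types or block["id"] in excluded_ids:
            continue
        kept.append(block)
        if len(kept) >= TARGET_RESULTS:
            break
    return kept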
@@ -0,0 +1,139 @@
|
|||||||
|
"""Tests for block filtering in FindBlockTool."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.find_block import (
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
|
FindBlockTool,
|
||||||
|
)
|
||||||
|
from backend.api.features.chat.tools.models import BlockListResponse
|
||||||
|
from backend.blocks._base import BlockType
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
_TEST_USER_ID = "test-user-find-block"
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_block(
|
||||||
|
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||||
|
):
|
||||||
|
"""Create a mock block for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.description = f"{name} description"
|
||||||
|
mock.block_type = block_type
|
||||||
|
mock.disabled = disabled
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||||
|
mock.input_schema.get_credentials_fields.return_value = {}
|
||||||
|
mock.output_schema = MagicMock()
|
||||||
|
mock.output_schema.jsonschema.return_value = {}
|
||||||
|
mock.categories = []
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindBlockFiltering:
|
||||||
|
"""Tests for block filtering in FindBlockTool."""
|
||||||
|
|
||||||
|
def test_excluded_block_types_contains_expected_types(self):
|
||||||
|
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
||||||
|
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
|
||||||
|
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
||||||
|
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||||
|
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_type_filtered_from_results(self):
|
||||||
|
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
||||||
|
search_results = [
|
||||||
|
{"content_id": "input-block-id", "score": 0.9},
|
||||||
|
{"content_id": "standard-block-id", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||||
|
standard_block = make_mock_block(
|
||||||
|
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_block(block_id):
|
||||||
|
return {
|
||||||
|
"input-block-id": input_block,
|
||||||
|
"standard-block-id": standard_block,
|
||||||
|
}.get(block_id)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, 2),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=mock_get_block,
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return the standard block, not the INPUT block
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert len(response.blocks) == 1
|
||||||
|
assert response.blocks[0].id == "standard-block-id"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_id_filtered_from_results(self):
|
||||||
|
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||||
|
search_results = [
|
||||||
|
{"content_id": smart_decision_id, "score": 0.9},
|
||||||
|
{"content_id": "normal-block-id", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
||||||
|
smart_block = make_mock_block(
|
||||||
|
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
normal_block = make_mock_block(
|
||||||
|
"normal-block-id", "Normal Block", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_block(block_id):
|
||||||
|
return {
|
||||||
|
smart_decision_id: smart_block,
|
||||||
|
"normal-block-id": normal_block,
|
||||||
|
}.get(block_id)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, 2),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=mock_get_block,
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="decision"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return normal block, not SmartDecisionMakerBlock
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert len(response.blocks) == 1
|
||||||
|
assert response.blocks[0].id == "normal-block-id"
|
||||||
@@ -0,0 +1,29 @@
+"""Shared helpers for chat tools."""
+
+from typing import Any
+
+
+def get_inputs_from_schema(
+    input_schema: dict[str, Any],
+    exclude_fields: set[str] | None = None,
+) -> list[dict[str, Any]]:
+    """Extract input field info from JSON schema."""
+    if not isinstance(input_schema, dict):
+        return []
+
+    exclude = exclude_fields or set()
+    properties = input_schema.get("properties", {})
+    required = set(input_schema.get("required", []))
+
+    return [
+        {
+            "name": name,
+            "title": schema.get("title", name),
+            "type": schema.get("type", "string"),
+            "description": schema.get("description", ""),
+            "required": name in required,
+            "default": schema.get("default"),
+        }
+        for name, schema in properties.items()
+        if name not in exclude
+    ]
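A quick usage example for the new helper, assuming it lives at backend/api/features/chat/tools/helpers.py as the later "from .helpers import get_inputs_from_schema" imports suggest; the expected output is shown in the trailing comment.

from backend.api.features.chat.tools.helpers import get_inputs_from_schema

schema = {
    "properties": {
        "topic": {"title": "Topic", "type": "string", "description": "What to write about"},
        "credentials": {"title": "Credentials", "type": "object"},
    },
    "required": ["topic"],
}

fields = get_inputs_from_schema(schema, exclude_fields={"credentials"})
# -> [{"name": "topic", "title": "Topic", "type": "string",
#      "description": "What to write about", "required": True, "default": None}]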
@@ -38,6 +38,8 @@ class ResponseType(str, Enum):
     OPERATION_STARTED = "operation_started"
     OPERATION_PENDING = "operation_pending"
     OPERATION_IN_PROGRESS = "operation_in_progress"
+    # Input validation
+    INPUT_VALIDATION_ERROR = "input_validation_error"
 
 
 # Base response model
@@ -68,6 +70,10 @@ class AgentInfo(BaseModel):
     has_external_trigger: bool | None = None
     new_output: bool | None = None
     graph_id: str | None = None
+    inputs: dict[str, Any] | None = Field(
+        default=None,
+        description="Input schema for the agent, including field names, types, and defaults",
+    )
 
 
 class AgentsFoundResponse(ToolResponseBase):
@@ -194,6 +200,20 @@ class ErrorResponse(ToolResponseBase):
     details: dict[str, Any] | None = None
 
 
+class InputValidationErrorResponse(ToolResponseBase):
+    """Response when run_agent receives unknown input fields."""
+
+    type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
+    unrecognized_fields: list[str] = Field(
+        description="List of input field names that were not recognized"
+    )
+    inputs: dict[str, Any] = Field(
+        description="The agent's valid input schema for reference"
+    )
+    graph_id: str | None = None
+    graph_version: int | None = None
+
+
 # Agent output models
 class ExecutionOutputInfo(BaseModel):
     """Summary of a single execution's outputs."""
@@ -352,11 +372,15 @@ class OperationStartedResponse(ToolResponseBase):
 
     This is returned immediately to the client while the operation continues
     to execute. The user can close the tab and check back later.
+
+    The task_id can be used to reconnect to the SSE stream via
+    GET /chat/tasks/{task_id}/stream?last_idx=0
     """
 
     type: ResponseType = ResponseType.OPERATION_STARTED
     operation_id: str
     tool_name: str
+    task_id: str | None = None  # For SSE reconnection
 
 
 class OperationPendingResponse(ToolResponseBase):
@@ -380,3 +404,20 @@ class OperationInProgressResponse(ToolResponseBase):
 
     type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
     tool_call_id: str
+
+
+class AsyncProcessingResponse(ToolResponseBase):
+    """Response when an operation has been delegated to async processing.
+
+    This is returned by tools when the external service accepts the request
+    for async processing (HTTP 202 Accepted). The Redis Streams completion
+    consumer will handle the result when the external service completes.
+
+    The status field is specifically "accepted" to allow the long-running tool
+    handler to detect this response and skip LLM continuation.
+    """
+
+    type: ResponseType = ResponseType.OPERATION_STARTED
+    status: str = "accepted"  # Must be "accepted" for detection
+    operation_id: str | None = None
+    task_id: str | None = None
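Because AsyncProcessingResponse deliberately reuses ResponseType.OPERATION_STARTED, detecting the async handoff hinges on the status field rather than on the response type. A one-line sketch of the check the docstring says the long-running tool handler performs; the handler itself is not part of this diff.

def is_async_handoff(response: object) -> bool:
    """True when a tool response signals the external service accepted the work (HTTP 202)."""
    return getattr(response, "status", None) == "accepted"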
@@ -24,12 +24,14 @@ from backend.util.timezone_utils import (
 )
 
 from .base import BaseTool
+from .helpers import get_inputs_from_schema
 from .models import (
     AgentDetails,
     AgentDetailsResponse,
     ErrorResponse,
     ExecutionOptions,
     ExecutionStartedResponse,
+    InputValidationErrorResponse,
     SetupInfo,
     SetupRequirementsResponse,
     ToolResponseBase,
@@ -260,7 +262,7 @@ class RunAgentTool(BaseTool):
                 ),
                 requirements={
                     "credentials": requirements_creds_list,
-                    "inputs": self._get_inputs_list(graph.input_schema),
+                    "inputs": get_inputs_from_schema(graph.input_schema),
                     "execution_modes": self._get_execution_modes(graph),
                 },
             ),
@@ -273,6 +275,22 @@ class RunAgentTool(BaseTool):
         input_properties = graph.input_schema.get("properties", {})
         required_fields = set(graph.input_schema.get("required", []))
         provided_inputs = set(params.inputs.keys())
+        valid_fields = set(input_properties.keys())
+
+        # Check for unknown input fields
+        unrecognized_fields = provided_inputs - valid_fields
+        if unrecognized_fields:
+            return InputValidationErrorResponse(
+                message=(
+                    f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
+                    f"Agent was not executed. Please use the correct field names from the schema."
+                ),
+                session_id=session_id,
+                unrecognized_fields=sorted(unrecognized_fields),
+                inputs=graph.input_schema,
+                graph_id=graph.id,
+                graph_version=graph.version,
+            )
 
         # If agent has inputs but none were provided AND use_defaults is not set,
         # always show what's available first so user can decide
@@ -352,22 +370,6 @@ class RunAgentTool(BaseTool):
             session_id=session_id,
         )
 
-    def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
-        """Extract inputs list from schema."""
-        inputs_list = []
-        if isinstance(input_schema, dict) and "properties" in input_schema:
-            for field_name, field_schema in input_schema["properties"].items():
-                inputs_list.append(
-                    {
-                        "name": field_name,
-                        "title": field_schema.get("title", field_name),
-                        "type": field_schema.get("type", "string"),
-                        "description": field_schema.get("description", ""),
-                        "required": field_name in input_schema.get("required", []),
-                    }
-                )
-        return inputs_list
-
     def _get_execution_modes(self, graph: GraphModel) -> list[str]:
         """Get available execution modes for the graph."""
         trigger_info = graph.trigger_setup_info
@@ -381,7 +383,7 @@ class RunAgentTool(BaseTool):
         suffix: str,
     ) -> str:
         """Build a message describing available inputs for an agent."""
-        inputs_list = self._get_inputs_list(graph.input_schema)
+        inputs_list = get_inputs_from_schema(graph.input_schema)
         required_names = [i["name"] for i in inputs_list if i["required"]]
         optional_names = [i["name"] for i in inputs_list if not i["required"]]
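The new guard is a plain set difference between the provided input names and the schema's declared properties. A standalone sketch of the same check, matching the behaviour the new test below asserts; the function name is illustrative.

def find_unrecognized_fields(provided_inputs: dict, input_schema: dict) -> list[str]:
    """Return input names that are not declared in the agent's input schema."""
    valid_fields = set(input_schema.get("properties", {}).keys())
    return sorted(set(provided_inputs) - valid_fields)

# Example: the schema declares only "topic", so both provided keys are rejected.
assert find_unrecognized_fields(
    {"unknown_field": 1, "another_unknown": 2},
    {"properties": {"topic": {"type": "string"}}},
) == ["another_unknown", "unknown_field"]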
@@ -402,3 +402,42 @@ async def test_run_agent_schedule_without_name(setup_test_data):
|
|||||||
# Should return error about missing schedule_name
|
# Should return error about missing schedule_name
|
||||||
assert result_data.get("type") == "error"
|
assert result_data.get("type") == "error"
|
||||||
assert "schedule_name" in result_data["message"].lower()
|
assert "schedule_name" in result_data["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
|
||||||
|
"""Test that run_agent returns input_validation_error for unknown input fields."""
|
||||||
|
user = setup_test_data["user"]
|
||||||
|
store_submission = setup_test_data["store_submission"]
|
||||||
|
|
||||||
|
tool = RunAgentTool()
|
||||||
|
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||||
|
session = make_session(user_id=user.id)
|
||||||
|
|
||||||
|
# Execute with unknown input field names
|
||||||
|
response = await tool.execute(
|
||||||
|
user_id=user.id,
|
||||||
|
session_id=str(uuid.uuid4()),
|
||||||
|
tool_call_id=str(uuid.uuid4()),
|
||||||
|
username_agent_slug=agent_marketplace_id,
|
||||||
|
inputs={
|
||||||
|
"unknown_field": "some value",
|
||||||
|
"another_unknown": "another value",
|
||||||
|
},
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert hasattr(response, "output")
|
||||||
|
assert isinstance(response.output, str)
|
||||||
|
result_data = orjson.loads(response.output)
|
||||||
|
|
||||||
|
# Should return input_validation_error type with unrecognized fields
|
||||||
|
assert result_data.get("type") == "input_validation_error"
|
||||||
|
assert "unrecognized_fields" in result_data
|
||||||
|
assert set(result_data["unrecognized_fields"]) == {
|
||||||
|
"another_unknown",
|
||||||
|
"unknown_field",
|
||||||
|
}
|
||||||
|
assert "inputs" in result_data # Contains the valid schema
|
||||||
|
assert "Agent was not executed" in result_data["message"]
|
||||||
|
|||||||
@@ -5,15 +5,23 @@ import uuid
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.block import get_block
|
from backend.api.features.chat.tools.find_block import (
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
|
)
|
||||||
|
from backend.blocks import get_block
|
||||||
|
from backend.blocks._base import AnyBlockSchema
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .helpers import get_inputs_from_schema
|
||||||
from .models import (
|
from .models import (
|
||||||
BlockOutputResponse,
|
BlockOutputResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -22,7 +30,10 @@ from .models import (
|
|||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
UserReadiness,
|
UserReadiness,
|
||||||
)
|
)
|
||||||
from .utils import build_missing_credentials_from_field_info
|
from .utils import (
|
||||||
|
build_missing_credentials_from_field_info,
|
||||||
|
match_credentials_to_requirements,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -71,65 +82,6 @@ class RunBlockTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _check_block_credentials(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
block: Any,
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Check if user has required credentials for a block.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[matched_credentials, missing_credentials]
|
|
||||||
"""
|
|
||||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
|
||||||
missing_credentials: list[CredentialsMetaInput] = []
|
|
||||||
|
|
||||||
# Get credential field info from block's input schema
|
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
|
||||||
|
|
||||||
if not credentials_fields_info:
|
|
||||||
return matched_credentials, missing_credentials
|
|
||||||
|
|
||||||
# Get user's available credentials
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
|
||||||
# field_info.provider is a frozenset of acceptable providers
|
|
||||||
# field_info.supported_types is a frozenset of acceptable types
|
|
||||||
matching_cred = next(
|
|
||||||
(
|
|
||||||
cred
|
|
||||||
for cred in available_creds
|
|
||||||
if cred.provider in field_info.provider
|
|
||||||
and cred.type in field_info.supported_types
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if matching_cred:
|
|
||||||
matched_credentials[field_name] = CredentialsMetaInput(
|
|
||||||
id=matching_cred.id,
|
|
||||||
provider=matching_cred.provider, # type: ignore
|
|
||||||
type=matching_cred.type,
|
|
||||||
title=matching_cred.title,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Create a placeholder for the missing credential
|
|
||||||
provider = next(iter(field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
|
||||||
missing_credentials.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=field_name.replace("_", " ").title(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matched_credentials, missing_credentials
|
|
||||||
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -184,12 +136,24 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if block is excluded from CoPilot (graph-only blocks)
|
||||||
|
if (
|
||||||
|
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
):
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
||||||
|
"This block is designed for use within graphs only."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||||
|
|
||||||
# Check credentials
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
-        creds_manager = IntegrationCredentialsManager()
-        matched_credentials, missing_credentials = await self._check_block_credentials(
-            user_id, block
-        )
+        matched_credentials, missing_credentials = (
+            await self._resolve_block_credentials(user_id, block, input_data)
+        )

         if missing_credentials:
@@ -318,29 +282,75 @@ class RunBlockTool(BaseTool):
             session_id=session_id,
         )

-    def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
+    async def _resolve_block_credentials(
+        self,
+        user_id: str,
+        block: AnyBlockSchema,
+        input_data: dict[str, Any] | None = None,
+    ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
+        """
+        Resolve credentials for a block by matching user's available credentials.
+
+        Args:
+            user_id: User ID
+            block: Block to resolve credentials for
+            input_data: Input data for the block (used to determine provider via discriminator)
+
+        Returns:
+            tuple of (matched_credentials, missing_credentials) - matched credentials
+            are used for block execution, missing ones indicate setup requirements.
+        """
+        input_data = input_data or {}
+        requirements = self._resolve_discriminated_credentials(block, input_data)
+
+        if not requirements:
+            return {}, []
+
+        return await match_credentials_to_requirements(user_id, requirements)
+
+    def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
         """Extract non-credential inputs from block schema."""
-        inputs_list = []
         schema = block.input_schema.jsonschema()
-        properties = schema.get("properties", {})
-        required_fields = set(schema.get("required", []))
-
-        # Get credential field names to exclude
         credentials_fields = set(block.input_schema.get_credentials_fields().keys())
+        return get_inputs_from_schema(schema, exclude_fields=credentials_fields)

-        for field_name, field_schema in properties.items():
-            # Skip credential fields
-            if field_name in credentials_fields:
-                continue
-
-            inputs_list.append(
-                {
-                    "name": field_name,
-                    "title": field_schema.get("title", field_name),
-                    "type": field_schema.get("type", "string"),
-                    "description": field_schema.get("description", ""),
-                    "required": field_name in required_fields,
-                }
-            )
-
-        return inputs_list
+    def _resolve_discriminated_credentials(
+        self,
+        block: AnyBlockSchema,
+        input_data: dict[str, Any],
+    ) -> dict[str, CredentialsFieldInfo]:
+        """Resolve credential requirements, applying discriminator logic where needed."""
+        credentials_fields_info = block.input_schema.get_credentials_fields_info()
+        if not credentials_fields_info:
+            return {}
+
+        resolved: dict[str, CredentialsFieldInfo] = {}
+
+        for field_name, field_info in credentials_fields_info.items():
+            effective_field_info = field_info
+
+            if field_info.discriminator and field_info.discriminator_mapping:
+                discriminator_value = input_data.get(field_info.discriminator)
+                if discriminator_value is None:
+                    field = block.input_schema.model_fields.get(
+                        field_info.discriminator
+                    )
+                    if field and field.default is not PydanticUndefined:
+                        discriminator_value = field.default
+
+                if (
+                    discriminator_value
+                    and discriminator_value in field_info.discriminator_mapping
+                ):
+                    effective_field_info = field_info.discriminate(discriminator_value)
+                    # For host-scoped credentials, add the discriminator value
+                    # (e.g., URL) so _credential_is_for_host can match it
+                    effective_field_info.discriminator_values.add(discriminator_value)
+                    logger.debug(
+                        f"Discriminated provider for {field_name}: "
+                        f"{discriminator_value} -> {effective_field_info.provider}"
+                    )
+
+            resolved[field_name] = effective_field_info
+
+        return resolved
@@ -0,0 +1,106 @@
+"""Tests for block execution guards in RunBlockTool."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from backend.api.features.chat.tools.models import ErrorResponse
+from backend.api.features.chat.tools.run_block import RunBlockTool
+from backend.blocks._base import BlockType
+
+from ._test_data import make_session
+
+_TEST_USER_ID = "test-user-run-block"
+
+
+def make_mock_block(
+    block_id: str, name: str, block_type: BlockType, disabled: bool = False
+):
+    """Create a mock block for testing."""
+    mock = MagicMock()
+    mock.id = block_id
+    mock.name = name
+    mock.block_type = block_type
+    mock.disabled = disabled
+    mock.input_schema = MagicMock()
+    mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
+    mock.input_schema.get_credentials_fields_info.return_value = []
+    return mock
+
+
+class TestRunBlockFiltering:
+    """Tests for block execution guards in RunBlockTool."""
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_excluded_block_type_returns_error(self):
+        """Attempting to execute a block with excluded BlockType returns error."""
+        session = make_session(user_id=_TEST_USER_ID)
+
+        input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
+
+        with patch(
+            "backend.api.features.chat.tools.run_block.get_block",
+            return_value=input_block,
+        ):
+            tool = RunBlockTool()
+            response = await tool._execute(
+                user_id=_TEST_USER_ID,
+                session=session,
+                block_id="input-block-id",
+                input_data={},
+            )
+
+        assert isinstance(response, ErrorResponse)
+        assert "cannot be run directly in CoPilot" in response.message
+        assert "designed for use within graphs only" in response.message
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_excluded_block_id_returns_error(self):
+        """Attempting to execute SmartDecisionMakerBlock returns error."""
+        session = make_session(user_id=_TEST_USER_ID)
+
+        smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
+        smart_block = make_mock_block(
+            smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
+        )
+
+        with patch(
+            "backend.api.features.chat.tools.run_block.get_block",
+            return_value=smart_block,
+        ):
+            tool = RunBlockTool()
+            response = await tool._execute(
+                user_id=_TEST_USER_ID,
+                session=session,
+                block_id=smart_decision_id,
+                input_data={},
+            )
+
+        assert isinstance(response, ErrorResponse)
+        assert "cannot be run directly in CoPilot" in response.message
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_non_excluded_block_passes_guard(self):
+        """Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
+        session = make_session(user_id=_TEST_USER_ID)
+
+        standard_block = make_mock_block(
+            "standard-id", "HTTP Request", BlockType.STANDARD
+        )
+
+        with patch(
+            "backend.api.features.chat.tools.run_block.get_block",
+            return_value=standard_block,
+        ):
+            tool = RunBlockTool()
+            response = await tool._execute(
+                user_id=_TEST_USER_ID,
+                session=session,
+                block_id="standard-id",
+                input_data={},
+            )
+
+        # Should NOT be an ErrorResponse about CoPilot exclusion
+        # (may be other errors like missing credentials, but not the exclusion guard)
+        if isinstance(response, ErrorResponse):
+            assert "cannot be run directly in CoPilot" not in response.message
@@ -6,9 +6,14 @@ from typing import Any
 from backend.api.features.library import db as library_db
 from backend.api.features.library import model as library_model
 from backend.api.features.store import db as store_db
-from backend.data import graph as graph_db
 from backend.data.graph import GraphModel
-from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
+from backend.data.model import (
+    Credentials,
+    CredentialsFieldInfo,
+    CredentialsMetaInput,
+    HostScopedCredentials,
+    OAuth2Credentials,
+)
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import NotFoundError

@@ -39,14 +44,8 @@ async def fetch_graph_from_store_slug(
         return None, None

     # Get the graph from store listing version
-    graph_meta = await store_db.get_available_graph(
-        store_agent.store_listing_version_id
-    )
-    graph = await graph_db.get_graph(
-        graph_id=graph_meta.id,
-        version=graph_meta.version,
-        user_id=None,  # Public access
-        include_subgraphs=True,
+    graph = await store_db.get_available_graph(
+        store_agent.store_listing_version_id, hide_nodes=False
     )
     return graph, store_agent

@@ -123,7 +122,7 @@ def build_missing_credentials_from_graph(

     return {
         field_key: _serialize_missing_credential(field_key, field_info)
-        for field_key, (field_info, _node_fields) in aggregated_fields.items()
+        for field_key, (field_info, _, _) in aggregated_fields.items()
         if field_key not in matched_keys
     }

@@ -225,6 +224,99 @@ async def get_or_create_library_agent(
     return library_agents[0]


+async def match_credentials_to_requirements(
+    user_id: str,
+    requirements: dict[str, CredentialsFieldInfo],
+) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
+    """
+    Match user's credentials against a dictionary of credential requirements.
+
+    This is the core matching logic shared by both graph and block credential matching.
+    """
+    matched: dict[str, CredentialsMetaInput] = {}
+    missing: list[CredentialsMetaInput] = []
+
+    if not requirements:
+        return matched, missing
+
+    available_creds = await get_user_credentials(user_id)
+
+    for field_name, field_info in requirements.items():
+        matching_cred = find_matching_credential(available_creds, field_info)
+
+        if matching_cred:
+            try:
+                matched[field_name] = create_credential_meta_from_match(matching_cred)
+            except Exception as e:
+                logger.error(
+                    f"Failed to create CredentialsMetaInput for field '{field_name}': "
+                    f"provider={matching_cred.provider}, type={matching_cred.type}, "
+                    f"credential_id={matching_cred.id}",
+                    exc_info=True,
+                )
+                provider = next(iter(field_info.provider), "unknown")
+                cred_type = next(iter(field_info.supported_types), "api_key")
+                missing.append(
+                    CredentialsMetaInput(
+                        id=field_name,
+                        provider=provider,  # type: ignore
+                        type=cred_type,  # type: ignore
+                        title=f"{field_name} (validation failed: {e})",
+                    )
+                )
+        else:
+            provider = next(iter(field_info.provider), "unknown")
+            cred_type = next(iter(field_info.supported_types), "api_key")
+            missing.append(
+                CredentialsMetaInput(
+                    id=field_name,
+                    provider=provider,  # type: ignore
+                    type=cred_type,  # type: ignore
+                    title=field_name.replace("_", " ").title(),
+                )
+            )
+
+    return matched, missing
+
+
+async def get_user_credentials(user_id: str) -> list[Credentials]:
+    """Get all available credentials for a user."""
+    creds_manager = IntegrationCredentialsManager()
+    return await creds_manager.store.get_all_creds(user_id)
+
+
+def find_matching_credential(
+    available_creds: list[Credentials],
+    field_info: CredentialsFieldInfo,
+) -> Credentials | None:
+    """Find a credential that matches the required provider, type, scopes, and host."""
+    for cred in available_creds:
+        if cred.provider not in field_info.provider:
+            continue
+        if cred.type not in field_info.supported_types:
+            continue
+        if cred.type == "oauth2" and not _credential_has_required_scopes(
+            cred, field_info
+        ):
+            continue
+        if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
+            continue
+        return cred
+    return None
+
+
+def create_credential_meta_from_match(
+    matching_cred: Credentials,
+) -> CredentialsMetaInput:
+    """Create a CredentialsMetaInput from a matched credential."""
+    return CredentialsMetaInput(
+        id=matching_cred.id,
+        provider=matching_cred.provider,  # type: ignore
+        type=matching_cred.type,
+        title=matching_cred.title,
+    )
+
+
 async def match_user_credentials_to_graph(
     user_id: str,
     graph: GraphModel,
@@ -264,7 +356,8 @@ async def match_user_credentials_to_graph(
     # provider is in the set of acceptable providers.
     for credential_field_name, (
         credential_requirements,
-        _node_fields,
+        _,
+        _,
     ) in aggregated_creds.items():
         # Find first matching credential by provider, type, and scopes
         matching_cred = next(
@@ -273,7 +366,14 @@ async def match_user_credentials_to_graph(
             for cred in available_creds
             if cred.provider in credential_requirements.provider
             and cred.type in credential_requirements.supported_types
-            and _credential_has_required_scopes(cred, credential_requirements)
+            and (
+                cred.type != "oauth2"
+                or _credential_has_required_scopes(cred, credential_requirements)
+            )
+            and (
+                cred.type != "host_scoped"
+                or _credential_is_for_host(cred, credential_requirements)
+            )
         ),
         None,
     )
@@ -318,27 +418,32 @@ async def match_user_credentials_to_graph(


 def _credential_has_required_scopes(
-    credential: Credentials,
+    credential: OAuth2Credentials,
     requirements: CredentialsFieldInfo,
 ) -> bool:
-    """
-    Check if a credential has all the scopes required by the block.
-
-    For OAuth2 credentials, verifies that the credential's scopes are a superset
-    of the required scopes. For other credential types, returns True (no scope check).
-    """
-    # Only OAuth2 credentials have scopes to check
-    if credential.type != "oauth2":
-        return True
-
+    """Check if an OAuth2 credential has all the scopes required by the input."""
     # If no scopes are required, any credential matches
     if not requirements.required_scopes:
         return True
-
-    # Check that credential scopes are a superset of required scopes
     return set(credential.scopes).issuperset(requirements.required_scopes)


+def _credential_is_for_host(
+    credential: HostScopedCredentials,
+    requirements: CredentialsFieldInfo,
+) -> bool:
+    """Check if a host-scoped credential matches the host required by the input."""
+    # We need to know the host to match host-scoped credentials to.
+    # Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
+    # to discriminator_values. No discriminator_values -> no host to match against.
+    if not requirements.discriminator_values:
+        return True
+
+    # Check that credential host matches required host.
+    # Host-scoped credential inputs are grouped by host, so any item from the set works.
+    return credential.matches_url(list(requirements.discriminator_values)[0])
+
+
 async def check_user_has_required_credentials(
     user_id: str,
     required_credentials: list[CredentialsMetaInput],
@@ -12,14 +12,16 @@ import backend.api.features.store.image_gen as store_image_gen
 import backend.api.features.store.media as store_media
 import backend.data.graph as graph_db
 import backend.data.integrations as integrations_db
-from backend.data.block import BlockInput
 from backend.data.db import transaction
 from backend.data.execution import get_graph_execution
 from backend.data.graph import GraphSettings
 from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
-from backend.data.model import CredentialsMetaInput
+from backend.data.model import CredentialsMetaInput, GraphInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
-from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
+from backend.integrations.webhooks.graph_lifecycle_hooks import (
+    on_graph_activate,
+    on_graph_deactivate,
+)
 from backend.util.clients import get_scheduler_client
 from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
 from backend.util.json import SafeJson

@@ -371,7 +373,7 @@ async def get_library_agent_by_graph_id(


 async def add_generated_agent_image(
-    graph: graph_db.BaseGraph,
+    graph: graph_db.GraphBaseMeta,
     user_id: str,
     library_agent_id: str,
 ) -> Optional[prisma.models.LibraryAgent]:

@@ -537,6 +539,92 @@ async def update_agent_version_in_library(
     return library_model.LibraryAgent.from_db(lib)


+async def create_graph_in_library(
+    graph: graph_db.Graph,
+    user_id: str,
+) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
+    """Create a new graph and add it to the user's library."""
+    graph.version = 1
+    graph_model = graph_db.make_graph_model(graph, user_id)
+    graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
+
+    created_graph = await graph_db.create_graph(graph_model, user_id)
+
+    library_agents = await create_library_agent(
+        graph=created_graph,
+        user_id=user_id,
+        sensitive_action_safe_mode=True,
+        create_library_agents_for_sub_graphs=False,
+    )
+
+    if created_graph.is_active:
+        created_graph = await on_graph_activate(created_graph, user_id=user_id)
+
+    return created_graph, library_agents[0]
+
+
+async def update_graph_in_library(
+    graph: graph_db.Graph,
+    user_id: str,
+) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
+    """Create a new version of an existing graph and update the library entry."""
+    existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
+    current_active_version = (
+        next((v for v in existing_versions if v.is_active), None)
+        if existing_versions
+        else None
+    )
+    graph.version = (
+        max(v.version for v in existing_versions) + 1 if existing_versions else 1
+    )
+
+    graph_model = graph_db.make_graph_model(graph, user_id)
+    graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
+
+    created_graph = await graph_db.create_graph(graph_model, user_id)
+
+    library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
+    if not library_agent:
+        raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
+
+    library_agent = await update_library_agent_version_and_settings(
+        user_id, created_graph
+    )
+
+    if created_graph.is_active:
+        created_graph = await on_graph_activate(created_graph, user_id=user_id)
+        await graph_db.set_graph_active_version(
+            graph_id=created_graph.id,
+            version=created_graph.version,
+            user_id=user_id,
+        )
+        if current_active_version:
+            await on_graph_deactivate(current_active_version, user_id=user_id)
+
+    return created_graph, library_agent
+
+
+async def update_library_agent_version_and_settings(
+    user_id: str, agent_graph: graph_db.GraphModel
+) -> library_model.LibraryAgent:
+    """Update library agent to point to new graph version and sync settings."""
+    library = await update_agent_version_in_library(
+        user_id, agent_graph.id, agent_graph.version
+    )
+    updated_settings = GraphSettings.from_graph(
+        graph=agent_graph,
+        hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
+        sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
+    )
+    if updated_settings != library.settings:
+        library = await update_library_agent(
+            library_agent_id=library.id,
+            user_id=user_id,
+            settings=updated_settings,
+        )
+    return library
+
+
 async def update_library_agent(
     library_agent_id: str,
     user_id: str,

@@ -1041,7 +1129,7 @@ async def create_preset_from_graph_execution(
 async def update_preset(
     user_id: str,
     preset_id: str,
-    inputs: Optional[BlockInput] = None,
+    inputs: Optional[GraphInput] = None,
     credentials: Optional[dict[str, CredentialsMetaInput]] = None,
     name: Optional[str] = None,
     description: Optional[str] = None,
@@ -6,9 +6,12 @@ import prisma.enums
 import prisma.models
 import pydantic

-from backend.data.block import BlockInput
 from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
-from backend.data.model import CredentialsMetaInput, is_credentials_field_name
+from backend.data.model import (
+    CredentialsMetaInput,
+    GraphInput,
+    is_credentials_field_name,
+)
 from backend.util.json import loads as json_loads
 from backend.util.models import Pagination

@@ -323,7 +326,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
     graph_id: str
     graph_version: int

-    inputs: BlockInput
+    inputs: GraphInput
     credentials: dict[str, CredentialsMetaInput]

     name: str

@@ -352,7 +355,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
     Request model used when updating a preset for a library agent.
     """

-    inputs: Optional[BlockInput] = None
+    inputs: Optional[GraphInput] = None
     credentials: Optional[dict[str, CredentialsMetaInput]] = None

     name: Optional[str] = None

@@ -395,7 +398,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
                 "Webhook must be included in AgentPreset query when webhookId is set"
             )

-        input_data: BlockInput = {}
+        input_data: GraphInput = {}
         input_credentials: dict[str, CredentialsMetaInput] = {}

         for preset_input in preset.InputPresets:
@@ -5,8 +5,8 @@ from typing import Optional
 import aiohttp
 from fastapi import HTTPException

+from backend.blocks import get_block
 from backend.data import graph as graph_db
-from backend.data.block import get_block
 from backend.util.settings import Settings

 from .models import ApiResponse, ChatRequest, GraphData
@@ -152,7 +152,7 @@ class BlockHandler(ContentHandler):

     async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
         """Fetch blocks without embeddings."""
-        from backend.data.block import get_blocks
+        from backend.blocks import get_blocks

         # Get all available blocks
         all_blocks = get_blocks()

@@ -249,7 +249,7 @@ class BlockHandler(ContentHandler):

     async def get_stats(self) -> dict[str, int]:
         """Get statistics about block embedding coverage."""
-        from backend.data.block import get_blocks
+        from backend.blocks import get_blocks

         all_blocks = get_blocks()
@@ -93,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
     mock_existing = []

     with patch(
-        "backend.data.block.get_blocks",
+        "backend.blocks.get_blocks",
         return_value=mock_blocks,
     ):
         with patch(

@@ -135,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
     mock_embedded = [{"count": 2}]

     with patch(
-        "backend.data.block.get_blocks",
+        "backend.blocks.get_blocks",
         return_value=mock_blocks,
     ):
         with patch(

@@ -327,7 +327,7 @@ async def test_block_handler_handles_missing_attributes():
     mock_blocks = {"block-minimal": mock_block_class}

     with patch(
-        "backend.data.block.get_blocks",
+        "backend.blocks.get_blocks",
         return_value=mock_blocks,
     ):
         with patch(

@@ -360,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
     mock_blocks = {"good-block": good_block, "bad-block": bad_block}

     with patch(
-        "backend.data.block.get_blocks",
+        "backend.blocks.get_blocks",
         return_value=mock_blocks,
     ):
         with patch(
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from datetime import datetime, timezone
-from typing import Any, Literal
+from typing import Any, Literal, overload

 import fastapi
 import prisma.enums

@@ -11,8 +11,8 @@ import prisma.types

 from backend.data.db import transaction
 from backend.data.graph import (
-    GraphMeta,
     GraphModel,
+    GraphModelWithoutNodes,
     get_graph,
     get_graph_as_admin,
     get_sub_graphs,

@@ -334,7 +334,22 @@ async def get_store_agent_details(
         raise DatabaseError("Failed to fetch agent details") from e


-async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
+@overload
+async def get_available_graph(
+    store_listing_version_id: str, hide_nodes: Literal[False]
+) -> GraphModel: ...
+
+
+@overload
+async def get_available_graph(
+    store_listing_version_id: str, hide_nodes: Literal[True] = True
+) -> GraphModelWithoutNodes: ...
+
+
+async def get_available_graph(
+    store_listing_version_id: str,
+    hide_nodes: bool = True,
+) -> GraphModelWithoutNodes | GraphModel:
     try:
         # Get avaialble, non-deleted store listing version
         store_listing_version = (

@@ -344,7 +359,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
                     "isAvailable": True,
                     "isDeleted": False,
                 },
-                include={"AgentGraph": {"include": {"Nodes": True}}},
+                include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
             )
         )

@@ -354,7 +369,9 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
                 detail=f"Store listing version {store_listing_version_id} not found",
             )

-        return GraphModel.from_db(store_listing_version.AgentGraph).meta()
+        return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
+            store_listing_version.AgentGraph
+        )

     except Exception as e:
         logger.error(f"Error getting agent: {e}")
@@ -662,7 +662,7 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
             )
             current_ids = {row["id"] for row in valid_agents}
         elif content_type == ContentType.BLOCK:
-            from backend.data.block import get_blocks
+            from backend.blocks import get_blocks

             current_ids = set(get_blocks().keys())
         elif content_type == ContentType.DOCUMENTATION:
@@ -454,6 +454,9 @@ async def test_unified_hybrid_search_pagination(
     cleanup_embeddings: list,
 ):
     """Test unified search pagination works correctly."""
+    # Use a unique search term to avoid matching other test data
+    unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
+
     # Create multiple items
     content_ids = []
     for i in range(5):

@@ -465,14 +468,14 @@ async def test_unified_hybrid_search_pagination(
             content_type=ContentType.BLOCK,
             content_id=content_id,
             embedding=mock_embedding,
-            searchable_text=f"pagination test item number {i}",
+            searchable_text=f"{unique_term} item number {i}",
             metadata={"index": i},
             user_id=None,
         )

     # Get first page
     page1_results, total1 = await unified_hybrid_search(
-        query="pagination test",
+        query=unique_term,
         content_types=[ContentType.BLOCK],
         page=1,
         page_size=2,

@@ -480,7 +483,7 @@ async def test_unified_hybrid_search_pagination(

     # Get second page
     page2_results, total2 = await unified_hybrid_search(
-        query="pagination test",
+        query=unique_term,
         content_types=[ContentType.BLOCK],
         page=2,
         page_size=2,
@@ -8,6 +8,7 @@ Includes BM25 reranking for improved lexical relevance.

 import logging
 import re
+import time
 from dataclasses import dataclass
 from typing import Any, Literal

@@ -362,7 +363,11 @@ async def unified_hybrid_search(
         LIMIT {limit_param} OFFSET {offset_param}
     """

-    results = await query_raw_with_schema(sql_query, *params)
+    try:
+        results = await query_raw_with_schema(sql_query, *params)
+    except Exception as e:
+        await _log_vector_error_diagnostics(e)
+        raise

     total = results[0]["total_count"] if results else 0
     # Apply BM25 reranking

@@ -686,7 +691,11 @@ async def hybrid_search(
         LIMIT {limit_param} OFFSET {offset_param}
     """

-    results = await query_raw_with_schema(sql_query, *params)
+    try:
+        results = await query_raw_with_schema(sql_query, *params)
+    except Exception as e:
+        await _log_vector_error_diagnostics(e)
+        raise

     total = results[0]["total_count"] if results else 0

@@ -718,6 +727,87 @@ async def hybrid_search_simple(
     return await hybrid_search(query=query, page=page, page_size=page_size)


+# ============================================================================
+# Diagnostics
+# ============================================================================
+
+# Rate limit: only log vector error diagnostics once per this interval
+_VECTOR_DIAG_INTERVAL_SECONDS = 60
+_last_vector_diag_time: float = 0
+
+
+async def _log_vector_error_diagnostics(error: Exception) -> None:
+    """Log diagnostic info when 'type vector does not exist' error occurs.
+
+    Note: Diagnostic queries use query_raw_with_schema which may run on a different
+    pooled connection than the one that failed. Session-level search_path can differ,
+    so these diagnostics show cluster-wide state, not necessarily the failed session.
+
+    Includes rate limiting to avoid log spam - only logs once per minute.
+    Caller should re-raise the error after calling this function.
+    """
+    global _last_vector_diag_time
+
+    # Check if this is the vector type error
+    error_str = str(error).lower()
+    if not (
+        "type" in error_str and "vector" in error_str and "does not exist" in error_str
+    ):
+        return
+
+    # Rate limit: only log once per interval
+    now = time.time()
+    if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
+        return
+    _last_vector_diag_time = now
+
+    try:
+        diagnostics: dict[str, object] = {}
+
+        try:
+            search_path_result = await query_raw_with_schema("SHOW search_path")
+            diagnostics["search_path"] = search_path_result
+        except Exception as e:
+            diagnostics["search_path"] = f"Error: {e}"
+
+        try:
+            schema_result = await query_raw_with_schema("SELECT current_schema()")
+            diagnostics["current_schema"] = schema_result
+        except Exception as e:
+            diagnostics["current_schema"] = f"Error: {e}"
+
+        try:
+            user_result = await query_raw_with_schema(
+                "SELECT current_user, session_user, current_database()"
+            )
+            diagnostics["user_info"] = user_result
+        except Exception as e:
+            diagnostics["user_info"] = f"Error: {e}"
+
+        try:
+            # Check pgvector extension installation (cluster-wide, stable info)
+            ext_result = await query_raw_with_schema(
+                "SELECT extname, extversion, nspname as schema "
+                "FROM pg_extension e "
+                "JOIN pg_namespace n ON e.extnamespace = n.oid "
+                "WHERE extname = 'vector'"
+            )
+            diagnostics["pgvector_extension"] = ext_result
+        except Exception as e:
+            diagnostics["pgvector_extension"] = f"Error: {e}"
+
+        logger.error(
+            f"Vector type error diagnostics:\n"
+            f" Error: {error}\n"
+            f" search_path: {diagnostics.get('search_path')}\n"
+            f" current_schema: {diagnostics.get('current_schema')}\n"
+            f" user_info: {diagnostics.get('user_info')}\n"
+            f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
+        )
+    except Exception as diag_error:
+        logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
+
+
 # Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
 # for existing code that expects the popularity parameter
 HybridSearchWeights = StoreAgentSearchWeights
@@ -7,16 +7,7 @@ from replicate.client import Client as ReplicateClient
 from replicate.exceptions import ReplicateError
 from replicate.helpers import FileOutput

-from backend.blocks.ideogram import (
-    AspectRatio,
-    ColorPalettePreset,
-    IdeogramModelBlock,
-    IdeogramModelName,
-    MagicPromptOption,
-    StyleType,
-    UpscaleOption,
-)
-from backend.data.graph import BaseGraph
+from backend.data.graph import GraphBaseMeta
 from backend.data.model import CredentialsMetaInput, ProviderName
 from backend.integrations.credentials_store import ideogram_credentials
 from backend.util.request import Requests

@@ -34,14 +25,14 @@ class ImageStyle(str, Enum):
     DIGITAL_ART = "digital art"


-async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
+async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
     if settings.config.use_agent_image_generation_v2:
         return await generate_agent_image_v2(graph=agent)
     else:
         return await generate_agent_image_v1(agent=agent)


-async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
+async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
     """
     Generate an image for an agent using Ideogram model.
     Returns:

@@ -50,18 +41,31 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
     if not ideogram_credentials.api_key:
         raise ValueError("Missing Ideogram API key")

+    from backend.blocks.ideogram import (
+        AspectRatio,
+        ColorPalettePreset,
+        IdeogramModelBlock,
+        IdeogramModelName,
+        MagicPromptOption,
+        StyleType,
+        UpscaleOption,
+    )
+
     name = graph.name
     description = f"{name} ({graph.description})" if graph.description else name

     prompt = (
-        f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
-        f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
-        f"along with recognizable objects directly associated with the primary function of a {name}. "
-        f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
-        f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
-        f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
-        f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
-        f"prioritizing clear visual storytelling and thematic clarity above all else."
+        "Create a visually striking retro-futuristic vector pop art illustration "
+        f'prominently featuring "{name}" in bold typography. The image clearly and '
+        f"literally depicts a {description}, along with recognizable objects directly "
+        f"associated with the primary function of a {name}. "
+        f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
+        f"clearly conveying the purpose of a {name}. "
+        "Maintain vibrant, limited-palette colors, sharp vector lines, "
+        "geometric shapes, flat illustration techniques, and solid colors "
+        "without gradients or shading. Preserve a retro-futuristic aesthetic "
+        "influenced by mid-century futurism and 1960s psychedelia, "
+        "prioritizing clear visual storytelling and thematic clarity above all else."
     )

     custom_colors = [

@@ -99,12 +103,12 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
     return io.BytesIO(response.content)


-async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
+async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
     """
     Generate an image for an agent using Flux model via Replicate API.

     Args:
-        agent (Graph): The agent to generate an image for
+        agent (GraphBaseMeta | AgentGraph): The agent to generate an image for

     Returns:
         io.BytesIO: The generated image as bytes

@@ -114,7 +118,13 @@ async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
         raise ValueError("Missing Replicate API key in settings")

     # Construct prompt from agent details
-    prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
+    prompt = (
+        "Create a visually engaging app store thumbnail for the AI agent "
+        "that highlights what it does in a clear and captivating way:\n"
+        f"- **Name**: {agent.name}\n"
+        f"- **Description**: {agent.description}\n"
+        f"Focus on showcasing its core functionality with an appealing design."
+    )

     # Set up Replicate client
     client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
@@ -278,7 +278,7 @@ async def get_agent(
 )
 async def get_graph_meta_by_store_listing_version_id(
     store_listing_version_id: str,
-) -> backend.data.graph.GraphMeta:
+) -> backend.data.graph.GraphModelWithoutNodes:
     """
     Get Agent Graph from Store Listing Version ID.
     """
@@ -40,10 +40,11 @@ from backend.api.model import (
     UpdateTimezoneRequest,
     UploadFileResponse,
 )
+from backend.blocks import get_block, get_blocks
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
 from backend.data.auth import api_key as api_key_db
-from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
+from backend.data.block import BlockInput, CompletedBlockOutput
 from backend.data.credit import (
     AutoTopUpConfig,
     RefundRequest,

@@ -101,7 +102,6 @@ from backend.util.timezone_utils import (
 from backend.util.virus_scanner import scan_content_safe

 from .library import db as library_db
-from .library import model as library_model
 from .store.model import StoreAgentDetails


@@ -823,18 +823,16 @@ async def update_graph(
     graph: graph_db.Graph,
     user_id: Annotated[str, Security(get_user_id)],
 ) -> graph_db.GraphModel:
-    # Sanity check
     if graph.id and graph.id != graph_id:
         raise HTTPException(400, detail="Graph ID does not match ID in URI")

-    # Determine new version
     existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
     if not existing_versions:
         raise HTTPException(404, detail=f"Graph #{graph_id} not found")
-    latest_version_number = max(g.version for g in existing_versions)
-    graph.version = latest_version_number + 1

+    graph.version = max(g.version for g in existing_versions) + 1
     current_active_version = next((v for v in existing_versions if v.is_active), None)

     graph = graph_db.make_graph_model(graph, user_id)
     graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
     graph.validate_graph(for_run=False)

@@ -842,27 +840,23 @@ async def update_graph(
     new_graph_version = await graph_db.create_graph(graph, user_id=user_id)

     if new_graph_version.is_active:
-        # Keep the library agent up to date with the new active version
-        await _update_library_agent_version_and_settings(user_id, new_graph_version)
-        # Handle activation of the new graph first to ensure continuity
+        await library_db.update_library_agent_version_and_settings(
+            user_id, new_graph_version
+        )
         new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
-        # Ensure new version is the only active version
         await graph_db.set_graph_active_version(
             graph_id=graph_id, version=new_graph_version.version, user_id=user_id
         )
         if current_active_version:
-            # Handle deactivation of the previously active version
             await on_graph_deactivate(current_active_version, user_id=user_id)

-    # Fetch new graph version *with sub-graphs* (needed for credentials input schema)
     new_graph_version_with_subgraphs = await graph_db.get_graph(
         graph_id,
         new_graph_version.version,
         user_id=user_id,
         include_subgraphs=True,
     )
-    assert new_graph_version_with_subgraphs  # make type checker happy
+    assert new_graph_version_with_subgraphs
     return new_graph_version_with_subgraphs


@@ -900,33 +894,15 @@ async def set_graph_active_version(
     )

     # Keep the library agent up to date with the new active version
-    await _update_library_agent_version_and_settings(user_id, new_active_graph)
+    await library_db.update_library_agent_version_and_settings(
+        user_id, new_active_graph
+    )

     if current_active_graph and current_active_graph.version != new_active_version:
         # Handle deactivation of the previously active version
         await on_graph_deactivate(current_active_graph, user_id=user_id)


-async def _update_library_agent_version_and_settings(
-    user_id: str, agent_graph: graph_db.GraphModel
-) -> library_model.LibraryAgent:
-    library = await library_db.update_agent_version_in_library(
-        user_id, agent_graph.id, agent_graph.version
-    )
-    updated_settings = GraphSettings.from_graph(
-        graph=agent_graph,
-        hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
-        sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
-    )
-    if updated_settings != library.settings:
-        library = await library_db.update_library_agent(
-            library_agent_id=library.id,
-            user_id=user_id,
-            settings=updated_settings,
-        )
-    return library
-
-
 @v1_router.patch(
     path="/graphs/{graph_id}/settings",
     summary="Update graph settings",
@@ -40,6 +40,10 @@ import backend.data.user
 import backend.integrations.webhooks.utils
 import backend.util.service
 import backend.util.settings
+from backend.api.features.chat.completion_consumer import (
+    start_completion_consumer,
+    stop_completion_consumer,
+)
 from backend.blocks.llm import DEFAULT_LLM_MODEL
 from backend.data.model import Credentials
 from backend.integrations.providers import ProviderName

@@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI):
     await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
     await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()

+    # Start chat completion consumer for Redis Streams notifications
+    try:
+        await start_completion_consumer()
+    except Exception as e:
+        logger.warning(f"Could not start chat completion consumer: {e}")
+
     with launch_darkly_context():
         yield

+    # Stop chat completion consumer
+    try:
+        await stop_completion_consumer()
+    except Exception as e:
+        logger.warning(f"Error stopping chat completion consumer: {e}")
+
     try:
         await shutdown_cloud_storage_handler()
     except Exception as e:
@@ -3,22 +3,19 @@ import logging
 import os
 import re
 from pathlib import Path
-from typing import TYPE_CHECKING, TypeVar
+from typing import Sequence, Type, TypeVar
 
+from backend.blocks._base import AnyBlockSchema, BlockType
 from backend.util.cache import cached
 
 logger = logging.getLogger(__name__)
 
-
-if TYPE_CHECKING:
-    from backend.data.block import Block
-
 T = TypeVar("T")
 
 
 @cached(ttl_seconds=3600)
-def load_all_blocks() -> dict[str, type["Block"]]:
-    from backend.data.block import Block
+def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
+    from backend.blocks._base import Block
     from backend.util.settings import Config
 
     # Check if example blocks should be loaded from settings
@@ -50,8 +47,8 @@ def load_all_blocks() -> dict[str, type["Block"]]:
         importlib.import_module(f".{module}", package=__name__)
 
     # Load all Block instances from the available modules
-    available_blocks: dict[str, type["Block"]] = {}
-    for block_cls in all_subclasses(Block):
+    available_blocks: dict[str, type["AnyBlockSchema"]] = {}
+    for block_cls in _all_subclasses(Block):
         class_name = block_cls.__name__
 
         if class_name.endswith("Base"):
@@ -64,7 +61,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
                 "please name the class with 'Base' at the end"
             )
 
-        block = block_cls.create()
+        block = block_cls()  # pyright: ignore[reportAbstractUsage]
 
         if not isinstance(block.id, str) or len(block.id) != 36:
             raise ValueError(
@@ -105,7 +102,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
        available_blocks[block.id] = block_cls
 
    # Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
-    from backend.data.block import is_block_auth_configured
+    from ._utils import is_block_auth_configured
 
    filtered_blocks = {}
    for block_id, block_cls in available_blocks.items():
@@ -115,11 +112,48 @@ def load_all_blocks() -> dict[str, type["Block"]]:
     return filtered_blocks
 
 
-__all__ = ["load_all_blocks"]
-
-
-def all_subclasses(cls: type[T]) -> list[type[T]]:
+def _all_subclasses(cls: type[T]) -> list[type[T]]:
     subclasses = cls.__subclasses__()
     for subclass in subclasses:
-        subclasses += all_subclasses(subclass)
+        subclasses += _all_subclasses(subclass)
     return subclasses
+
+
+# ============== Block access helper functions ============== #
+
+
+def get_blocks() -> dict[str, Type["AnyBlockSchema"]]:
+    return load_all_blocks()
+
+
+# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
+def get_block(block_id: str) -> "AnyBlockSchema | None":
+    cls = get_blocks().get(block_id)
+    return cls() if cls else None
+
+
+@cached(ttl_seconds=3600)
+def get_webhook_block_ids() -> Sequence[str]:
+    return [
+        id
+        for id, B in get_blocks().items()
+        if B().block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
+    ]
+
+
+@cached(ttl_seconds=3600)
+def get_io_block_ids() -> Sequence[str]:
+    return [
+        id
+        for id, B in get_blocks().items()
+        if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
+    ]
+
+
+@cached(ttl_seconds=3600)
+def get_human_in_the_loop_block_ids() -> Sequence[str]:
+    return [
+        id
+        for id, B in get_blocks().items()
+        if B().block_type == BlockType.HUMAN_IN_THE_LOOP
+    ]
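The helpers added at the bottom of this module replace lookups that previously lived in backend.data.block (the agent block diff further down switches to `from backend.blocks import get_block`). A rough usage sketch; the block ID here is simply the TextEncoderBlock ID that appears later in this diff, used for illustration:

from backend.blocks import get_block, get_blocks, get_webhook_block_ids

block = get_block("5185f32e-4b65-4ecf-8fbb-873f003f09d6")  # returns an instance, or None
if block is not None:
    print(block.name, block.block_type)

print(len(get_blocks()))              # every registered block class, keyed by block ID
print(list(get_webhook_block_ids()))  # IDs of webhook-triggered blocks (cached for an hour)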
739  autogpt_platform/backend/backend/blocks/_base.py  (new file)
@@ -0,0 +1,739 @@
|
|||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
ClassVar,
|
||||||
|
Generic,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeAlias,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
get_origin,
|
||||||
|
)
|
||||||
|
|
||||||
|
import jsonref
|
||||||
|
import jsonschema
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
|
||||||
|
from backend.data.model import (
|
||||||
|
Credentials,
|
||||||
|
CredentialsFieldInfo,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
SchemaField,
|
||||||
|
is_credentials_field_name,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util import json
|
||||||
|
from backend.util.exceptions import (
|
||||||
|
BlockError,
|
||||||
|
BlockExecutionError,
|
||||||
|
BlockInputError,
|
||||||
|
BlockOutputError,
|
||||||
|
BlockUnknownError,
|
||||||
|
)
|
||||||
|
from backend.util.settings import Config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import ContributorDetails, NodeExecutionStats
|
||||||
|
|
||||||
|
from ..data.graph import Link
|
||||||
|
|
||||||
|
app_config = Config()
|
||||||
|
|
||||||
|
|
||||||
|
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
|
||||||
|
|
||||||
|
|
||||||
|
class BlockType(Enum):
|
||||||
|
STANDARD = "Standard"
|
||||||
|
INPUT = "Input"
|
||||||
|
OUTPUT = "Output"
|
||||||
|
NOTE = "Note"
|
||||||
|
WEBHOOK = "Webhook"
|
||||||
|
WEBHOOK_MANUAL = "Webhook (manual)"
|
||||||
|
AGENT = "Agent"
|
||||||
|
AI = "AI"
|
||||||
|
AYRSHARE = "Ayrshare"
|
||||||
|
HUMAN_IN_THE_LOOP = "Human In The Loop"
|
||||||
|
|
||||||
|
|
||||||
|
class BlockCategory(Enum):
|
||||||
|
AI = "Block that leverages AI to perform a task."
|
||||||
|
SOCIAL = "Block that interacts with social media platforms."
|
||||||
|
TEXT = "Block that processes text data."
|
||||||
|
SEARCH = "Block that searches or extracts information from the internet."
|
||||||
|
BASIC = "Block that performs basic operations."
|
||||||
|
INPUT = "Block that interacts with input of the graph."
|
||||||
|
OUTPUT = "Block that interacts with output of the graph."
|
||||||
|
LOGIC = "Programming logic to control the flow of your agent"
|
||||||
|
COMMUNICATION = "Block that interacts with communication platforms."
|
||||||
|
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
|
||||||
|
DATA = "Block that interacts with structured data."
|
||||||
|
HARDWARE = "Block that interacts with hardware."
|
||||||
|
AGENT = "Block that interacts with other agents."
|
||||||
|
CRM = "Block that interacts with CRM services."
|
||||||
|
SAFETY = (
|
||||||
|
"Block that provides AI safety mechanisms such as detecting harmful content"
|
||||||
|
)
|
||||||
|
PRODUCTIVITY = "Block that helps with productivity"
|
||||||
|
ISSUE_TRACKING = "Block that helps with issue tracking"
|
||||||
|
MULTIMEDIA = "Block that interacts with multimedia content"
|
||||||
|
MARKETING = "Block that helps with marketing"
|
||||||
|
|
||||||
|
def dict(self) -> dict[str, str]:
|
||||||
|
return {"category": self.name, "description": self.value}
|
||||||
|
|
||||||
|
|
||||||
|
class BlockCostType(str, Enum):
|
||||||
|
RUN = "run" # cost X credits per run
|
||||||
|
BYTE = "byte" # cost X credits per byte
|
||||||
|
SECOND = "second" # cost X credits per second
|
||||||
|
|
||||||
|
|
||||||
|
class BlockCost(BaseModel):
|
||||||
|
cost_amount: int
|
||||||
|
cost_filter: BlockInput
|
||||||
|
cost_type: BlockCostType
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cost_amount: int,
|
||||||
|
cost_type: BlockCostType = BlockCostType.RUN,
|
||||||
|
cost_filter: Optional[BlockInput] = None,
|
||||||
|
**data: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
cost_amount=cost_amount,
|
||||||
|
cost_filter=cost_filter or {},
|
||||||
|
cost_type=cost_type,
|
||||||
|
**data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockInfo(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
inputSchema: dict[str, Any]
|
||||||
|
outputSchema: dict[str, Any]
|
||||||
|
costs: list[BlockCost]
|
||||||
|
description: str
|
||||||
|
categories: list[dict[str, str]]
|
||||||
|
contributors: list[dict[str, Any]]
|
||||||
|
staticOutput: bool
|
||||||
|
uiType: str
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSchema(BaseModel):
|
||||||
|
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def jsonschema(cls) -> dict[str, Any]:
|
||||||
|
if cls.cached_jsonschema:
|
||||||
|
return cls.cached_jsonschema
|
||||||
|
|
||||||
|
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
||||||
|
|
||||||
|
def ref_to_dict(obj):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
# OpenAPI <3.1 does not support sibling fields that have a $ref key
|
||||||
|
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
|
||||||
|
keys = {"allOf", "anyOf", "oneOf"}
|
||||||
|
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
|
||||||
|
if one_key:
|
||||||
|
obj.update(obj[one_key][0])
|
||||||
|
|
||||||
|
return {
|
||||||
|
key: ref_to_dict(value)
|
||||||
|
for key, value in obj.items()
|
||||||
|
if not key.startswith("$") and key != one_key
|
||||||
|
}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [ref_to_dict(item) for item in obj]
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|
||||||
|
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||||
|
|
||||||
|
return cls.cached_jsonschema
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_data(cls, data: BlockInput) -> str | None:
|
||||||
|
return json.validate_with_jsonschema(
|
||||||
|
schema=cls.jsonschema(),
|
||||||
|
data={k: v for k, v in data.items() if v is not None},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||||
|
return cls.validate_data(data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
|
||||||
|
model_schema = cls.jsonschema().get("properties", {})
|
||||||
|
if not model_schema:
|
||||||
|
raise ValueError(f"Invalid model schema {cls}")
|
||||||
|
|
||||||
|
property_schema = model_schema.get(field_name)
|
||||||
|
if not property_schema:
|
||||||
|
raise ValueError(f"Invalid property name {field_name}")
|
||||||
|
|
||||||
|
return property_schema
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
|
||||||
|
"""
|
||||||
|
Validate the data against a specific property (one of the input/output name).
|
||||||
|
Returns the validation error message if the data does not match the schema.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
property_schema = cls.get_field_schema(field_name)
|
||||||
|
jsonschema.validate(json.to_dict(data), property_schema)
|
||||||
|
return None
|
||||||
|
except jsonschema.ValidationError as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_fields(cls) -> set[str]:
|
||||||
|
return set(cls.model_fields.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_required_fields(cls) -> set[str]:
|
||||||
|
return {
|
||||||
|
field
|
||||||
|
for field, field_info in cls.model_fields.items()
|
||||||
|
if field_info.is_required()
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __pydantic_init_subclass__(cls, **kwargs):
|
||||||
|
"""Validates the schema definition. Rules:
|
||||||
|
- Fields with annotation `CredentialsMetaInput` MUST be
|
||||||
|
named `credentials` or `*_credentials`
|
||||||
|
- Fields named `credentials` or `*_credentials` MUST be
|
||||||
|
of type `CredentialsMetaInput`
|
||||||
|
"""
|
||||||
|
super().__pydantic_init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
# Reset cached JSON schema to prevent inheriting it from parent class
|
||||||
|
cls.cached_jsonschema = {}
|
||||||
|
|
||||||
|
credentials_fields = cls.get_credentials_fields()
|
||||||
|
|
||||||
|
for field_name in cls.get_fields():
|
||||||
|
if is_credentials_field_name(field_name):
|
||||||
|
if field_name not in credentials_fields:
|
||||||
|
raise TypeError(
|
||||||
|
f"Credentials field '{field_name}' on {cls.__qualname__} "
|
||||||
|
f"is not of type {CredentialsMetaInput.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
CredentialsMetaInput.validate_credentials_field_schema(
|
||||||
|
cls.get_field_schema(field_name), field_name
|
||||||
|
)
|
||||||
|
|
||||||
|
elif field_name in credentials_fields:
|
||||||
|
raise KeyError(
|
||||||
|
f"Credentials field '{field_name}' on {cls.__qualname__} "
|
||||||
|
"has invalid name: must be 'credentials' or *_credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
|
||||||
|
return {
|
||||||
|
field_name: info.annotation
|
||||||
|
for field_name, info in cls.model_fields.items()
|
||||||
|
if (
|
||||||
|
inspect.isclass(info.annotation)
|
||||||
|
and issubclass(
|
||||||
|
get_origin(info.annotation) or info.annotation,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_auto_credentials_fields(cls) -> dict[str, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get fields that have auto_credentials metadata (e.g., GoogleDriveFileInput).
|
||||||
|
|
||||||
|
Returns a dict mapping kwarg_name -> {field_name, auto_credentials_config}
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If multiple fields have the same kwarg_name, as this would
|
||||||
|
cause silent overwriting and only the last field would be processed.
|
||||||
|
"""
|
||||||
|
result: dict[str, dict[str, Any]] = {}
|
||||||
|
schema = cls.jsonschema()
|
||||||
|
properties = schema.get("properties", {})
|
||||||
|
|
||||||
|
for field_name, field_schema in properties.items():
|
||||||
|
auto_creds = field_schema.get("auto_credentials")
|
||||||
|
if auto_creds:
|
||||||
|
kwarg_name = auto_creds.get("kwarg_name", "credentials")
|
||||||
|
if kwarg_name in result:
|
||||||
|
raise ValueError(
|
||||||
|
f"Duplicate auto_credentials kwarg_name '{kwarg_name}' "
|
||||||
|
f"in fields '{result[kwarg_name]['field_name']}' and "
|
||||||
|
f"'{field_name}' on {cls.__qualname__}"
|
||||||
|
)
|
||||||
|
result[kwarg_name] = {
|
||||||
|
"field_name": field_name,
|
||||||
|
"config": auto_creds,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# Regular credentials fields
|
||||||
|
for field_name in cls.get_credentials_fields().keys():
|
||||||
|
result[field_name] = CredentialsFieldInfo.model_validate(
|
||||||
|
cls.get_field_schema(field_name), by_alias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto-generated credentials fields (from GoogleDriveFileInput etc.)
|
||||||
|
for kwarg_name, info in cls.get_auto_credentials_fields().items():
|
||||||
|
config = info["config"]
|
||||||
|
# Build a schema-like dict that CredentialsFieldInfo can parse
|
||||||
|
auto_schema = {
|
||||||
|
"credentials_provider": [config.get("provider", "google")],
|
||||||
|
"credentials_types": [config.get("type", "oauth2")],
|
||||||
|
"credentials_scopes": config.get("scopes"),
|
||||||
|
}
|
||||||
|
result[kwarg_name] = CredentialsFieldInfo.model_validate(
|
||||||
|
auto_schema, by_alias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||||
|
return data # Return as is, by default.
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
||||||
|
input_fields_from_nodes = {link.sink_name for link in links}
|
||||||
|
return input_fields_from_nodes - set(data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||||
|
return cls.get_required_fields() - set(data)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSchemaInput(BlockSchema):
|
||||||
|
"""
|
||||||
|
Base schema class for block inputs.
|
||||||
|
All block input schemas should extend this class for consistency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSchemaOutput(BlockSchema):
|
||||||
|
"""
|
||||||
|
Base schema class for block outputs that includes a standard error field.
|
||||||
|
All block output schemas should extend this class to ensure consistent error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
error: str = SchemaField(
|
||||||
|
description="Error message if the operation failed", default=""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchemaInput)
|
||||||
|
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchemaOutput)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyInputSchema(BlockSchemaInput):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyOutputSchema(BlockSchemaOutput):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# For backward compatibility - will be deprecated
|
||||||
|
EmptySchema = EmptyOutputSchema
|
||||||
|
|
||||||
|
|
||||||
|
# --8<-- [start:BlockWebhookConfig]
|
||||||
|
class BlockManualWebhookConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration model for webhook-triggered blocks on which
|
||||||
|
the user has to manually set up the webhook at the provider.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: ProviderName
|
||||||
|
"""The service provider that the webhook connects to"""
|
||||||
|
|
||||||
|
webhook_type: str
|
||||||
|
"""
|
||||||
|
Identifier for the webhook type. E.g. GitHub has repo and organization level hooks.
|
||||||
|
|
||||||
|
Only for use in the corresponding `WebhooksManager`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_filter_input: str = ""
|
||||||
|
"""
|
||||||
|
Name of the block's event filter input.
|
||||||
|
Leave empty if the corresponding webhook doesn't have distinct event/payload types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_format: str = "{event}"
|
||||||
|
"""
|
||||||
|
Template string for the event(s) that a block instance subscribes to.
|
||||||
|
Applied individually to each event selected in the event filter input.
|
||||||
|
|
||||||
|
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||||
|
"""
|
||||||
|
Configuration model for webhook-triggered blocks for which
|
||||||
|
the webhook can be automatically set up through the provider's API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
resource_format: str
|
||||||
|
"""
|
||||||
|
Template string for the resource that a block instance subscribes to.
|
||||||
|
Fields will be filled from the block's inputs (except `payload`).
|
||||||
|
|
||||||
|
Example: `f"{repo}/pull_requests"` (note: not how it's actually implemented)
|
||||||
|
|
||||||
|
Only for use in the corresponding `WebhooksManager`.
|
||||||
|
"""
|
||||||
|
# --8<-- [end:BlockWebhookConfig]
|
||||||
|
|
||||||
|
|
||||||
|
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: str = "",
|
||||||
|
description: str = "",
|
||||||
|
contributors: list["ContributorDetails"] = [],
|
||||||
|
categories: set[BlockCategory] | None = None,
|
||||||
|
input_schema: Type[BlockSchemaInputType] = EmptyInputSchema,
|
||||||
|
output_schema: Type[BlockSchemaOutputType] = EmptyOutputSchema,
|
||||||
|
test_input: BlockInput | list[BlockInput] | None = None,
|
||||||
|
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
|
||||||
|
test_mock: dict[str, Any] | None = None,
|
||||||
|
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
|
||||||
|
disabled: bool = False,
|
||||||
|
static_output: bool = False,
|
||||||
|
block_type: BlockType = BlockType.STANDARD,
|
||||||
|
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
|
||||||
|
is_sensitive_action: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the block with the given schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id: The unique identifier for the block, this value will be persisted in the
|
||||||
|
DB. So it should be unique and constant across the application run.
|
||||||
|
Use the UUID format for the ID.
|
||||||
|
description: The description of the block, explaining what the block does.
|
||||||
|
contributors: The list of contributors who contributed to the block.
|
||||||
|
input_schema: The schema, defined as a Pydantic model, for the input data.
|
||||||
|
output_schema: The schema, defined as a Pydantic model, for the output data.
|
||||||
|
test_input: The list or single sample input data for the block, for testing.
|
||||||
|
test_output: The list or single expected output if the test_input is run.
|
||||||
|
test_mock: function names on the block implementation to mock on test run.
|
||||||
|
disabled: If the block is disabled, it will not be available for execution.
|
||||||
|
static_output: Whether the output links of the block are static by default.
|
||||||
|
"""
|
||||||
|
from backend.data.model import NodeExecutionStats
|
||||||
|
|
||||||
|
self.id = id
|
||||||
|
self.input_schema = input_schema
|
||||||
|
self.output_schema = output_schema
|
||||||
|
self.test_input = test_input
|
||||||
|
self.test_output = test_output
|
||||||
|
self.test_mock = test_mock
|
||||||
|
self.test_credentials = test_credentials
|
||||||
|
self.description = description
|
||||||
|
self.categories = categories or set()
|
||||||
|
self.contributors = contributors or set()
|
||||||
|
self.disabled = disabled
|
||||||
|
self.static_output = static_output
|
||||||
|
self.block_type = block_type
|
||||||
|
self.webhook_config = webhook_config
|
||||||
|
self.is_sensitive_action = is_sensitive_action
|
||||||
|
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
||||||
|
|
||||||
|
if self.webhook_config:
|
||||||
|
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||||
|
# Enforce presence of credentials field on auto-setup webhook blocks
|
||||||
|
if not (cred_fields := self.input_schema.get_credentials_fields()):
|
||||||
|
raise TypeError(
|
||||||
|
"credentials field is required on auto-setup webhook blocks"
|
||||||
|
)
|
||||||
|
# Disallow multiple credentials inputs on webhook blocks
|
||||||
|
elif len(cred_fields) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Multiple credentials inputs not supported on webhook blocks"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.block_type = BlockType.WEBHOOK
|
||||||
|
else:
|
||||||
|
self.block_type = BlockType.WEBHOOK_MANUAL
|
||||||
|
|
||||||
|
# Enforce shape of webhook event filter, if present
|
||||||
|
if self.webhook_config.event_filter_input:
|
||||||
|
event_filter_field = self.input_schema.model_fields[
|
||||||
|
self.webhook_config.event_filter_input
|
||||||
|
]
|
||||||
|
if not (
|
||||||
|
isinstance(event_filter_field.annotation, type)
|
||||||
|
and issubclass(event_filter_field.annotation, BaseModel)
|
||||||
|
and all(
|
||||||
|
field.annotation is bool
|
||||||
|
for field in event_filter_field.annotation.model_fields.values()
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.name} has an invalid webhook event selector: "
|
||||||
|
"field must be a BaseModel and all its fields must be boolean"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enforce presence of 'payload' input
|
||||||
|
if "payload" not in self.input_schema.model_fields:
|
||||||
|
raise TypeError(
|
||||||
|
f"{self.name} is webhook-triggered but has no 'payload' input"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Disable webhook-triggered block if webhook functionality not available
|
||||||
|
if not app_config.platform_base_url:
|
||||||
|
self.disabled = True
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
|
||||||
|
"""
|
||||||
|
Run the block with the given input data.
|
||||||
|
Args:
|
||||||
|
input_data: The input data with the structure of input_schema.
|
||||||
|
|
||||||
|
Kwargs: As of 14/02/2025, these include
|
||||||
|
graph_id: The ID of the graph.
|
||||||
|
node_id: The ID of the node.
|
||||||
|
graph_exec_id: The ID of the graph execution.
|
||||||
|
node_exec_id: The ID of the node execution.
|
||||||
|
user_id: The ID of the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Generator that yields (output_name, output_data).
|
||||||
|
output_name: One of the output names defined in Block's output_schema.
|
||||||
|
output_data: The data for the output_name, matching the defined schema.
|
||||||
|
"""
|
||||||
|
# --- satisfy the type checker, never executed -------------
|
||||||
|
if False: # noqa: SIM115
|
||||||
|
yield "name", "value" # pyright: ignore[reportMissingYield]
|
||||||
|
raise NotImplementedError(f"{self.name} does not implement the run method.")
|
||||||
|
|
||||||
|
async def run_once(
|
||||||
|
self, input_data: BlockSchemaInputType, output: str, **kwargs
|
||||||
|
) -> Any:
|
||||||
|
async for item in self.run(input_data, **kwargs):
|
||||||
|
name, data = item
|
||||||
|
if name == output:
|
||||||
|
return data
|
||||||
|
raise ValueError(f"{self.name} did not produce any output for {output}")
|
||||||
|
|
||||||
|
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
|
||||||
|
self.execution_stats += stats
|
||||||
|
return self.execution_stats
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"inputSchema": self.input_schema.jsonschema(),
|
||||||
|
"outputSchema": self.output_schema.jsonschema(),
|
||||||
|
"description": self.description,
|
||||||
|
"categories": [category.dict() for category in self.categories],
|
||||||
|
"contributors": [
|
||||||
|
contributor.model_dump() for contributor in self.contributors
|
||||||
|
],
|
||||||
|
"staticOutput": self.static_output,
|
||||||
|
"uiType": self.block_type.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_info(self) -> BlockInfo:
|
||||||
|
from backend.data.credit import get_block_cost
|
||||||
|
|
||||||
|
return BlockInfo(
|
||||||
|
id=self.id,
|
||||||
|
name=self.name,
|
||||||
|
inputSchema=self.input_schema.jsonschema(),
|
||||||
|
outputSchema=self.output_schema.jsonschema(),
|
||||||
|
costs=get_block_cost(self),
|
||||||
|
description=self.description,
|
||||||
|
categories=[category.dict() for category in self.categories],
|
||||||
|
contributors=[
|
||||||
|
contributor.model_dump() for contributor in self.contributors
|
||||||
|
],
|
||||||
|
staticOutput=self.static_output,
|
||||||
|
uiType=self.block_type.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
async for output_name, output_data in self._execute(input_data, **kwargs):
|
||||||
|
yield output_name, output_data
|
||||||
|
except Exception as ex:
|
||||||
|
if isinstance(ex, BlockError):
|
||||||
|
raise ex
|
||||||
|
else:
|
||||||
|
raise (
|
||||||
|
BlockExecutionError
|
||||||
|
if isinstance(ex, ValueError)
|
||||||
|
else BlockUnknownError
|
||||||
|
)(
|
||||||
|
message=str(ex),
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=self.id,
|
||||||
|
) from ex
|
||||||
|
|
||||||
|
async def is_block_exec_need_review(
|
||||||
|
self,
|
||||||
|
input_data: BlockInput,
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
node_id: str,
|
||||||
|
node_exec_id: str,
|
||||||
|
graph_exec_id: str,
|
||||||
|
graph_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
execution_context: "ExecutionContext",
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[bool, BlockInput]:
|
||||||
|
"""
|
||||||
|
Check if this block execution needs human review and handle the review process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (should_pause, input_data_to_use)
|
||||||
|
- should_pause: True if execution should be paused for review
|
||||||
|
- input_data_to_use: The input data to use (may be modified by reviewer)
|
||||||
|
"""
|
||||||
|
if not (
|
||||||
|
self.is_sensitive_action and execution_context.sensitive_action_safe_mode
|
||||||
|
):
|
||||||
|
return False, input_data
|
||||||
|
|
||||||
|
from backend.blocks.helpers.review import HITLReviewHelper
|
||||||
|
|
||||||
|
# Handle the review request and get decision
|
||||||
|
decision = await HITLReviewHelper.handle_review_decision(
|
||||||
|
input_data=input_data,
|
||||||
|
user_id=user_id,
|
||||||
|
node_id=node_id,
|
||||||
|
node_exec_id=node_exec_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
graph_id=graph_id,
|
||||||
|
graph_version=graph_version,
|
||||||
|
block_name=self.name,
|
||||||
|
editable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if decision is None:
|
||||||
|
# We're awaiting review - pause execution
|
||||||
|
return True, input_data
|
||||||
|
|
||||||
|
if not decision.should_proceed:
|
||||||
|
# Review was rejected, raise an error to stop execution
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Block execution rejected by reviewer: {decision.message}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=self.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Review was approved - use the potentially modified data
|
||||||
|
# ReviewResult.data must be a dict for block inputs
|
||||||
|
reviewed_data = decision.review_result.data
|
||||||
|
if not isinstance(reviewed_data, dict):
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=self.id,
|
||||||
|
)
|
||||||
|
return False, reviewed_data
|
||||||
|
|
||||||
|
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||||
|
# Check for review requirement only if running within a graph execution context
|
||||||
|
# Direct block execution (e.g., from chat) skips the review process
|
||||||
|
has_graph_context = all(
|
||||||
|
key in kwargs
|
||||||
|
for key in (
|
||||||
|
"node_exec_id",
|
||||||
|
"graph_exec_id",
|
||||||
|
"graph_id",
|
||||||
|
"execution_context",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if has_graph_context:
|
||||||
|
should_pause, input_data = await self.is_block_exec_need_review(
|
||||||
|
input_data, **kwargs
|
||||||
|
)
|
||||||
|
if should_pause:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Validate the input data (original or reviewer-modified) once
|
||||||
|
if error := self.input_schema.validate_data(input_data):
|
||||||
|
raise BlockInputError(
|
||||||
|
message=f"Unable to execute block with invalid input data: {error}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=self.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the validated input data
|
||||||
|
async for output_name, output_data in self.run(
|
||||||
|
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if output_name == "error":
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=output_data, block_name=self.name, block_id=self.id
|
||||||
|
)
|
||||||
|
if self.block_type == BlockType.STANDARD and (
|
||||||
|
error := self.output_schema.validate_field(output_name, output_data)
|
||||||
|
):
|
||||||
|
raise BlockOutputError(
|
||||||
|
message=f"Block produced an invalid output data: {error}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=self.id,
|
||||||
|
)
|
||||||
|
yield output_name, output_data
|
||||||
|
|
||||||
|
def is_triggered_by_event_type(
|
||||||
|
self, trigger_config: dict[str, Any], event_type: str
|
||||||
|
) -> bool:
|
||||||
|
if not self.webhook_config:
|
||||||
|
raise TypeError("This method can't be used on non-trigger blocks")
|
||||||
|
if not self.webhook_config.event_filter_input:
|
||||||
|
return True
|
||||||
|
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
|
||||||
|
if not event_filter:
|
||||||
|
raise ValueError("Event filter is not configured on trigger")
|
||||||
|
return event_type in [
|
||||||
|
self.webhook_config.event_format.format(event=k)
|
||||||
|
for k in event_filter
|
||||||
|
if event_filter[k] is True
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for any block with standard input/output schemas
|
||||||
|
AnyBlockSchema: TypeAlias = Block[BlockSchemaInput, BlockSchemaOutput]
|
||||||
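To make the new base module concrete: a minimal sketch of a block built on these classes, loosely following the TextEncoderBlock added later in this diff. The class, its fields, and the UUID below are illustrative only and are not part of this changeset.

from backend.blocks._base import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import SchemaField


class ReverseTextBlock(Block):
    """Illustrative block: reverses the input string."""

    class Input(BlockSchemaInput):
        text: str = SchemaField(description="Text to reverse")

    class Output(BlockSchemaOutput):
        reversed_text: str = SchemaField(description="The input text, reversed")

    def __init__(self):
        super().__init__(
            id="00000000-0000-4000-8000-000000000000",  # placeholder UUID, not a real block ID
            description="Reverses a string",
            categories={BlockCategory.TEXT},
            input_schema=ReverseTextBlock.Input,
            output_schema=ReverseTextBlock.Output,
            test_input={"text": "abc"},
            test_output=[("reversed_text", "cba")],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Yield (output_name, output_data) pairs, as the Block.run contract requires.
        yield "reversed_text", input_data.text[::-1]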
122  autogpt_platform/backend/backend/blocks/_utils.py  (new file)
@@ -0,0 +1,122 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
from ._base import AnyBlockSchema
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def is_block_auth_configured(
|
||||||
|
block_cls: type[AnyBlockSchema],
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a block has a valid authentication method configured at runtime.
|
||||||
|
|
||||||
|
For example, if a block is an OAuth-only block and the env vars are not set,
|
||||||
|
do not show it in the UI.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from backend.sdk.registry import AutoRegistry
|
||||||
|
|
||||||
|
# Create an instance to access input_schema
|
||||||
|
try:
|
||||||
|
block = block_cls()
|
||||||
|
except Exception as e:
|
||||||
|
# If we can't create a block instance, assume it's not OAuth-only
|
||||||
|
logger.error(f"Error creating block instance for {block_cls.__name__}: {e}")
|
||||||
|
return True
|
||||||
|
logger.debug(
|
||||||
|
f"Checking if block {block_cls.__name__} has a valid provider configured"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all credential inputs from input schema
|
||||||
|
credential_inputs = block.input_schema.get_credentials_fields_info()
|
||||||
|
required_inputs = block.input_schema.get_required_fields()
|
||||||
|
if not credential_inputs:
|
||||||
|
logger.debug(
|
||||||
|
f"Block {block_cls.__name__} has no credential inputs - Treating as valid"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check credential inputs
|
||||||
|
if len(required_inputs.intersection(credential_inputs.keys())) == 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Block {block_cls.__name__} has only optional credential inputs"
|
||||||
|
" - will work without credentials configured"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the credential inputs for this block are correctly configured
|
||||||
|
for field_name, field_info in credential_inputs.items():
|
||||||
|
provider_names = field_info.provider
|
||||||
|
if not provider_names:
|
||||||
|
logger.warning(
|
||||||
|
f"Block {block_cls.__name__} "
|
||||||
|
f"has credential input '{field_name}' with no provider options"
|
||||||
|
" - Disabling"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If a field has multiple possible providers, each one needs to be usable to
|
||||||
|
# prevent breaking the UX
|
||||||
|
for _provider_name in provider_names:
|
||||||
|
provider_name = _provider_name.value
|
||||||
|
if provider_name in ProviderName.__members__.values():
|
||||||
|
logger.debug(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"provider '{provider_name}' is part of the legacy provider system"
|
||||||
|
" - Treating as valid"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
provider = AutoRegistry.get_provider(provider_name)
|
||||||
|
if not provider:
|
||||||
|
logger.warning(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"refers to unknown provider '{provider_name}' - Disabling"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check the provider's supported auth types
|
||||||
|
if field_info.supported_types != provider.supported_auth_types:
|
||||||
|
logger.warning(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"has mismatched supported auth types (field <> Provider): "
|
||||||
|
f"{field_info.supported_types} != {provider.supported_auth_types}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not (supported_auth_types := provider.supported_auth_types):
|
||||||
|
# No auth methods have been configured for this provider
|
||||||
|
logger.warning(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"provider '{provider_name}' "
|
||||||
|
"has no authentication methods configured - Disabling"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if provider supports OAuth
|
||||||
|
if "oauth2" in supported_auth_types:
|
||||||
|
# Check if OAuth environment variables are set
|
||||||
|
if (oauth_config := provider.oauth_config) and bool(
|
||||||
|
os.getenv(oauth_config.client_id_env_var)
|
||||||
|
and os.getenv(oauth_config.client_secret_env_var)
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"provider '{provider_name}' is configured for OAuth"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"provider '{provider_name}' "
|
||||||
|
"is missing OAuth client ID or secret - Disabling"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' is valid; "
|
||||||
|
f"supported credential types: {', '.join(field_info.supported_types)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
@@ -1,7 +1,7 @@
 import logging
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockInput,
@@ -9,13 +9,15 @@ from backend.data.block import (
     BlockSchema,
     BlockSchemaInput,
     BlockType,
-    get_block,
 )
 from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
 from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util.json import validate_with_jsonschema
 from backend.util.retry import func_retry
 
+if TYPE_CHECKING:
+    from backend.executor.utils import LogMetadata
+
 _logger = logging.getLogger(__name__)
 
 
@@ -124,9 +126,10 @@ class AgentExecutorBlock(Block):
         graph_version: int,
         graph_exec_id: str,
         user_id: str,
-        logger,
+        logger: "LogMetadata",
     ) -> BlockOutput:
 
+        from backend.blocks import get_block
         from backend.data.execution import ExecutionEventType
         from backend.executor import utils as execution_utils
 
@@ -198,7 +201,7 @@ class AgentExecutorBlock(Block):
         self,
         graph_exec_id: str,
         user_id: str,
-        logger,
+        logger: "LogMetadata",
     ) -> None:
         from backend.executor import utils as execution_utils
 
@@ -1,5 +1,11 @@
 from typing import Any
 
+from backend.blocks._base import (
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.blocks.llm import (
     DEFAULT_LLM_MODEL,
     TEST_CREDENTIALS,
@@ -11,12 +17,6 @@ from backend.blocks.llm import (
     LLMResponse,
     llm_call,
 )
-from backend.data.block import (
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
 from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
 
 
@@ -6,7 +6,7 @@ from pydantic import SecretStr
|
|||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -5,7 +5,12 @@ from pydantic import SecretStr
|
|||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal
|
|||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -10,13 +17,6 @@ from backend.blocks.apollo.models import (
|
|||||||
PrimaryPhone,
|
PrimaryPhone,
|
||||||
SearchOrganizationsRequest,
|
SearchOrganizationsRequest,
|
||||||
)
|
)
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -14,13 +21,6 @@ from backend.blocks.apollo.models import (
|
|||||||
SearchPeopleRequest,
|
SearchPeopleRequest,
|
||||||
SenorityLevels,
|
SenorityLevels,
|
||||||
)
|
)
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -6,13 +13,6 @@ from backend.blocks.apollo._auth import (
|
|||||||
ApolloCredentialsInput,
|
ApolloCredentialsInput,
|
||||||
)
|
)
|
||||||
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from backend.data.block import BlockSchemaInput
|
from backend.blocks._base import BlockSchemaInput
|
||||||
from backend.data.model import SchemaField, UserIntegrations
|
from backend.data.model import SchemaField, UserIntegrations
|
||||||
from backend.integrations.ayrshare import AyrshareClient
|
from backend.integrations.ayrshare import AyrshareClient
|
||||||
from backend.util.clients import get_database_manager_async_client
|
from backend.util.clients import get_database_manager_async_client
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import enum
|
import enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal, Optional
|
|||||||
from e2b import AsyncSandbox as BaseAsyncSandbox
|
from e2b import AsyncSandbox as BaseAsyncSandbox
|
||||||
from pydantic import BaseModel, SecretStr
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from e2b_code_interpreter import Result as E2BExecutionResult
|
|||||||
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
||||||
from pydantic import BaseModel, Field, JsonValue, SecretStr
|
from pydantic import BaseModel, Field, JsonValue, SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from openai import AsyncOpenAI
|
|||||||
from openai.types.responses import Response as OpenAIResponse
|
from openai.types.responses import Response as OpenAIResponse
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockManualWebhookConfig,
|
BlockManualWebhookConfig,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import codecs
|
import codecs
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Any, Literal, cast
|
|||||||
import discord
|
import discord
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
Discord OAuth-based blocks.
|
Discord OAuth-based blocks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""ElevenLabs integration blocks - test credentials and shared utilities."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="elevenlabs",
|
||||||
|
api_key=SecretStr("mock-elevenlabs-api-key"),
|
||||||
|
title="Mock ElevenLabs API key",
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_CREDENTIALS.provider,
|
||||||
|
"id": TEST_CREDENTIALS.id,
|
||||||
|
"type": TEST_CREDENTIALS.type,
|
||||||
|
"title": TEST_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
ElevenLabsCredentials = APIKeyCredentials
|
||||||
|
ElevenLabsCredentialsInput = CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.ELEVENLABS], Literal["api_key"]
|
||||||
|
]
|
||||||
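These helpers follow the usual provider _auth pattern: a mock credential for tests plus a typed CredentialsMetaInput alias. A rough sketch of how a block's input schema might consume them; apart from the imported names, everything below is hypothetical and not part of this diff:

from backend.blocks._base import BlockSchemaInput
from backend.blocks.elevenlabs._auth import ElevenLabsCredentialsInput
from backend.data.model import CredentialsField, SchemaField


class ExampleElevenLabsInput(BlockSchemaInput):
    # The field must be named `credentials` (or end in `_credentials`);
    # BlockSchema.__pydantic_init_subclass__ enforces this naming rule.
    credentials: ElevenLabsCredentialsInput = CredentialsField(
        description="ElevenLabs API key"
    )
    text: str = SchemaField(description="Text to synthesize")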
@@ -7,7 +7,7 @@ from typing import Literal
 
 from pydantic import BaseModel, ConfigDict, SecretStr
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
77  autogpt_platform/backend/backend/blocks/encoder_block.py  (new file)
@@ -0,0 +1,77 @@
|
|||||||
|
"""Text encoding block for converting special characters to escape sequences."""
|
||||||
|
|
||||||
|
import codecs
|
||||||
|
|
||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncoderBlock(Block):
|
||||||
|
"""
|
||||||
|
Encodes a string by converting special characters into escape sequences.
|
||||||
|
|
||||||
|
This block is the inverse of TextDecoderBlock. It takes text containing
|
||||||
|
special characters (like newlines, tabs, etc.) and converts them into
|
||||||
|
their escape sequence representations (e.g., newline becomes \\n).
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
"""Input schema for TextEncoderBlock."""
|
||||||
|
|
||||||
|
text: str = SchemaField(
|
||||||
|
description="A string containing special characters to be encoded",
|
||||||
|
placeholder="Your text with newlines and quotes to encode",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
"""Output schema for TextEncoderBlock."""
|
||||||
|
|
||||||
|
encoded_text: str = SchemaField(
|
||||||
|
description="The encoded text with special characters converted to escape sequences"
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if encoding fails")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
|
||||||
|
description="Encodes a string by converting special characters into escape sequences",
|
||||||
|
categories={BlockCategory.TEXT},
|
||||||
|
input_schema=TextEncoderBlock.Input,
|
||||||
|
output_schema=TextEncoderBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"text": """Hello
|
||||||
|
World!
|
||||||
|
This is a "quoted" string."""
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"encoded_text",
|
||||||
|
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||||
|
"""
|
||||||
|
Encode the input text by converting special characters to escape sequences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data: The input containing the text to encode.
|
||||||
|
**kwargs: Additional keyword arguments (unused).
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The encoded text with escape sequences, or an error message if encoding fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
yield "encoded_text", encoded_text
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", f"Encoding error: {str(e)}"
|
||||||
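The encoder block above leans entirely on the standard codecs module; a quick round-trip sketch of the behaviour its test case expects (TextDecoderBlock, described as its inverse, is not part of this diff):

import codecs

raw = 'Hello\nWorld!\nThis is a "quoted" string.'

# Encoding turns real control characters into visible escape sequences.
encoded = codecs.encode(raw, "unicode_escape").decode("utf-8")
print(encoded)  # Hello\nWorld!\nThis is a "quoted" string.  (with literal backslash-n)

# Decoding with the same codec restores the original text.
decoded = codecs.decode(encoded, "unicode_escape")
assert decoded == raw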
@@ -8,7 +8,7 @@ which provides access to LinkedIn profile data and related information.
 import logging
 from typing import Optional
 
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
 
         try:
-            webset = aexa.websets.get(id=input_data.external_id)
+            webset = await aexa.websets.get(id=input_data.external_id)
             webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
 
             yield "webset", webset_result
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
                 count=input_data.search_count,
             )
 
-            webset = aexa.websets.create(
+            webset = await aexa.websets.create(
                 params=CreateWebsetParameters(
                     search=search_params,
                     external_id=input_data.external_id,
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
         if input_data.metadata is not None:
             payload["metadata"] = input_data.metadata
 
-        sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
+        sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
 
         status_str = (
             sdk_webset.status.value
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
     ) -> BlockOutput:
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
 
-        response = aexa.websets.list(
+        response = await aexa.websets.list(
             cursor=input_data.cursor,
             limit=input_data.limit,
         )
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
     ) -> BlockOutput:
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
 
-        sdk_webset = aexa.websets.get(id=input_data.webset_id)
+        sdk_webset = await aexa.websets.get(id=input_data.webset_id)
 
         status_str = (
             sdk_webset.status.value
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
     ) -> BlockOutput:
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
 
-        deleted_webset = aexa.websets.delete(id=input_data.webset_id)
+        deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
 
         status_str = (
             deleted_webset.status.value
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
     ) -> BlockOutput:
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
 
-        canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
+        canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
 
         status_str = (
             canceled_webset.status.value
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
|
|||||||
entity["description"] = input_data.entity_description
|
entity["description"] = input_data.entity_description
|
||||||
payload["entity"] = entity
|
payload["entity"] = entity
|
||||||
|
|
||||||
sdk_preview = aexa.websets.preview(params=payload)
|
sdk_preview = await aexa.websets.preview(params=payload)
|
||||||
|
|
||||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||||
|
|
||||||
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Extract basic info
|
# Extract basic info
|
||||||
webset_id = webset.id
|
webset_id = webset.id
|
||||||
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
total_items = 0
|
total_items = 0
|
||||||
|
|
||||||
if input_data.include_sample_items and input_data.sample_size > 0:
|
if input_data.include_sample_items and input_data.sample_size > 0:
|
||||||
items_response = aexa.websets.items.list(
|
items_response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
sample_items_data = [
|
sample_items_data = [
|
||||||
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset details
|
# Get webset details
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
|
|||||||
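All of the webset hunks above make the same mechanical change: each call on the AsyncExa client gains an await, because the async SDK returns coroutines rather than results. A minimal sketch of the pattern outside any block follows; the import path and the webset id are assumptions for illustration, not taken from the diff.

import asyncio

from exa_py import AsyncExa  # assumed import path for the async client


async def fetch_webset_status(api_key: str, webset_id: str) -> str:
    aexa = AsyncExa(api_key=api_key)
    # Without `await`, `webset` would be a coroutine object and attribute
    # access such as `webset.status` would fail -- the bug these hunks fix.
    webset = await aexa.websets.get(id=webset_id)
    status = webset.status
    return status.value if hasattr(status, "value") else str(status)


# asyncio.run(fetch_webset_status("EXA_API_KEY", "ws_example"))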
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        sdk_enrichment = aexa.websets.enrichments.create(
+        sdk_enrichment = await aexa.websets.enrichments.create(
            webset_id=input_data.webset_id, params=payload
        )

@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
            items_enriched = 0

            while time.time() - poll_start < input_data.polling_timeout:
-                current_enrich = aexa.websets.enrichments.get(
+                current_enrich = await aexa.websets.enrichments.get(
                    webset_id=input_data.webset_id, id=enrichment_id
                )
                current_status = (
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):

                if current_status in ["completed", "failed", "cancelled"]:
                    # Estimate items from webset searches
-                    webset = aexa.websets.get(id=input_data.webset_id)
+                    webset = await aexa.websets.get(id=input_data.webset_id)
                    if webset.searches:
                        for search in webset.searches:
                            if search.progress:
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        sdk_enrichment = aexa.websets.enrichments.get(
+        sdk_enrichment = await aexa.websets.enrichments.get(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )

@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        deleted_enrichment = aexa.websets.enrichments.delete(
+        deleted_enrichment = await aexa.websets.enrichments.delete(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )

@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        canceled_enrichment = aexa.websets.enrichments.cancel(
+        canceled_enrichment = await aexa.websets.enrichments.cancel(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )

        # Try to estimate how many items were enriched before cancellation
        items_enriched = 0
-        items_response = aexa.websets.items.list(
+        items_response = await aexa.websets.items.list(
            webset_id=input_data.webset_id, limit=100
        )

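The enrichment-creation hunk above polls the enrichment status inside a timeout loop, so each check now also has to be awaited. A generic sketch of that poll-until-terminal pattern follows; every name here is illustrative rather than taken from the file.

import asyncio
import time


async def poll_until_terminal(fetch_status, timeout: float = 120.0, interval: float = 2.0) -> str:
    """Await `fetch_status()` repeatedly until it reports a terminal state.

    `fetch_status` stands in for an awaited SDK call such as
    `aexa.websets.enrichments.get(...)`; the terminal states mirror the
    ones checked in the hunk ("completed", "failed", "cancelled").
    """
    start = time.time()
    status = "pending"
    while time.time() - start < timeout:
        status = await fetch_status()
        if status in ("completed", "failed", "cancelled"):
            break
        await asyncio.sleep(interval)
    return status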
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from datetime import datetime
-        from unittest.mock import MagicMock
+        from unittest.mock import AsyncMock, MagicMock

        # Create mock SDK import object
        mock_import = MagicMock()
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
-                    imports=MagicMock(create=lambda *args, **kwargs: mock_import)
+                    imports=MagicMock(create=AsyncMock(return_value=mock_import))
                )
            )
        }
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

-        sdk_import = aexa.websets.imports.create(
+        sdk_import = await aexa.websets.imports.create(
            params=payload, csv_data=input_data.csv_data
        )

@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
+        sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)

        import_obj = ImportModel.from_sdk(sdk_import)

@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        response = aexa.websets.imports.list(
+        response = await aexa.websets.imports.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
        )
@@ -474,7 +474,9 @@ class ExaDeleteImportBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
+        deleted_import = await aexa.websets.imports.delete(
+            import_id=input_data.import_id
+        )

        yield "import_id", deleted_import.id
        yield "success", "true"
@@ -573,14 +575,14 @@ class ExaExportWebsetBlock(Block):
                }
            )

-        # Create mock iterator
-        mock_items = [mock_item1, mock_item2]
+        # Create async iterator for list_all
+        async def async_item_iterator(*args, **kwargs):
+            for item in [mock_item1, mock_item2]:
+                yield item

        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
-                websets=MagicMock(
-                    items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
-                )
+                websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
            )
        }

@@ -602,7 +604,7 @@ class ExaExportWebsetBlock(Block):
            webset_id=input_data.webset_id, limit=input_data.max_items
        )

-        for sdk_item in item_iterator:
+        async for sdk_item in item_iterator:
            if len(all_items) >= input_data.max_items:
                break

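The mock changes in the two test-mock hunks follow directly from the await/async-for conversions: a plain lambda is not awaitable, and a list iterator cannot back an async for loop. A small self-contained sketch of both fixes follows; all attribute names here are illustrative, not the blocks' real ones.

import asyncio
from unittest.mock import AsyncMock, MagicMock


async def demo() -> None:
    client = MagicMock()

    # AsyncMock returns an awaitable, so `await client.imports.create(...)`
    # works; a bare lambda returning a MagicMock would raise
    # "object MagicMock can't be used in 'await' expression".
    client.imports.create = AsyncMock(return_value={"id": "import_123"})
    created = await client.imports.create(params={})
    assert created["id"] == "import_123"

    # `async for` needs an async iterator, so the list_all mock becomes an
    # async generator function instead of a function returning iter([...]).
    async def fake_list_all(*args, **kwargs):
        for item in ("item_1", "item_2"):
            yield item

    client.items.list_all = fake_list_all
    collected = [item async for item in client.items.list_all(limit=2)]
    assert collected == ["item_1", "item_2"]


asyncio.run(demo())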
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        sdk_item = aexa.websets.items.get(
+        sdk_item = await aexa.websets.items.get(
            webset_id=input_data.webset_id, id=input_data.item_id
        )

@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
            response = None

            while time.time() - start_time < input_data.wait_timeout:
-                response = aexa.websets.items.list(
+                response = await aexa.websets.items.list(
                    webset_id=input_data.webset_id,
                    cursor=input_data.cursor,
                    limit=input_data.limit,
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
                interval = min(interval * 1.2, 10)

            if not response:
-                response = aexa.websets.items.list(
+                response = await aexa.websets.items.list(
                    webset_id=input_data.webset_id,
                    cursor=input_data.cursor,
                    limit=input_data.limit,
                )
        else:
-            response = aexa.websets.items.list(
+            response = await aexa.websets.items.list(
                webset_id=input_data.webset_id,
                cursor=input_data.cursor,
                limit=input_data.limit,
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        deleted_item = aexa.websets.items.delete(
+        deleted_item = await aexa.websets.items.delete(
            webset_id=input_data.webset_id, id=input_data.item_id
        )

@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
            webset_id=input_data.webset_id, limit=input_data.max_items
        )

-        for sdk_item in item_iterator:
+        async for sdk_item in item_iterator:
            if len(all_items) >= input_data.max_items:
                break

@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        webset = aexa.websets.get(id=input_data.webset_id)
+        webset = await aexa.websets.get(id=input_data.webset_id)

        entity_type = "unknown"
        if webset.searches:
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
        # Get sample items if requested
        sample_items: List[WebsetItemModel] = []
        if input_data.sample_size > 0:
-            items_response = aexa.websets.items.list(
+            items_response = await aexa.websets.items.list(
                webset_id=input_data.webset_id, limit=input_data.sample_size
            )
            # Convert to our stable models
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Get items starting from cursor
-        response = aexa.websets.items.list(
+        response = await aexa.websets.items.list(
            webset_id=input_data.webset_id,
            cursor=input_data.since_cursor,
            limit=input_data.max_items,
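The bulk and export hunks change plain for loops over the item iterator into async for loops while keeping the cap-and-break shape. A compact illustration of consuming an async item iterator up to a maximum count follows; the item source here is a stand-in, not the SDK.

import asyncio
from typing import AsyncIterator, List


async def fake_item_iterator(total: int) -> AsyncIterator[dict]:
    # Stand-in for an async iterator like aexa.websets.items.list_all(...)
    for i in range(total):
        yield {"id": f"item_{i}"}


async def collect_items(max_items: int) -> List[dict]:
    all_items: List[dict] = []
    async for sdk_item in fake_item_iterator(total=10):
        if len(all_items) >= max_items:
            break
        all_items.append(sdk_item)
    return all_items


print(asyncio.run(collect_items(max_items=3)))  # collects three items, then stops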
Some files were not shown because too many files have changed in this diff.