mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-13 08:14:58 -05:00
Compare commits
145 Commits
docs/works
...
feat/copil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60c6c14211 | ||
|
|
a86878ae6f | ||
|
|
80e413b969 | ||
|
|
52c8a25531 | ||
|
|
d0f0c32e70 | ||
|
|
8dfd0a77a0 | ||
|
|
4bfd6c8870 | ||
|
|
1918828405 | ||
|
|
9c855b501b | ||
|
|
5c9d0577c0 | ||
|
|
a79bd88e7c | ||
|
|
28c1121a8f | ||
|
|
cb3839198c | ||
|
|
80804986b0 | ||
|
|
43b25b5e2f | ||
|
|
ab0b537cc7 | ||
|
|
b915e67a9b | ||
|
|
9a8c6ad609 | ||
|
|
32e9dda30d | ||
|
|
e8c50b96d1 | ||
|
|
30e854569a | ||
|
|
301d7cbada | ||
|
|
d95aef7665 | ||
|
|
cb45e7957b | ||
|
|
f1d02fb8f3 | ||
|
|
47de6b6420 | ||
|
|
62cd2eea89 | ||
|
|
ae61ec692e | ||
|
|
9296bd8736 | ||
|
|
308113c03d | ||
|
|
51abf13254 | ||
|
|
54b03d3a29 | ||
|
|
239dff5ebd | ||
|
|
1dd53db21c | ||
|
|
06c16ee2fe | ||
|
|
8d2a649ee5 | ||
|
|
cb166dd6fb | ||
|
|
9589474709 | ||
|
|
3d31f62bf1 | ||
|
|
b8b6c9de23 | ||
|
|
749a78723a | ||
|
|
bec2e1ddee | ||
|
|
ec1ab06e0d | ||
|
|
f31cb49557 | ||
|
|
fd28c386f4 | ||
|
|
3bea584659 | ||
|
|
4f6055f494 | ||
|
|
695a185fa1 | ||
|
|
113e87a23c | ||
|
|
d09f1532a4 | ||
|
|
d7f7a2747f | ||
|
|
68849e197c | ||
|
|
211478bb29 | ||
|
|
0e88dd15b2 | ||
|
|
7f3c227f0a | ||
|
|
40b58807ab | ||
|
|
d0e2e6f013 | ||
|
|
efdc8d73cc | ||
|
|
a34810d8a2 | ||
|
|
038b7d5841 | ||
|
|
cac93b0cc9 | ||
|
|
a78145505b | ||
|
|
2025aaf5f2 | ||
|
|
ae9bce3bae | ||
|
|
3107d889fc | ||
|
|
f174fb6303 | ||
|
|
920a4c5f15 | ||
|
|
e95fadbb86 | ||
|
|
b14b3803ad | ||
|
|
36aeb0b2b3 | ||
|
|
2a189c44c4 | ||
|
|
508759610f | ||
|
|
062fe1aa70 | ||
|
|
82c483d6c8 | ||
|
|
7cffa1895f | ||
|
|
9791bdd724 | ||
|
|
750a674c78 | ||
|
|
960c7980a3 | ||
|
|
e85d437bb2 | ||
|
|
2cd0d4fe0f | ||
|
|
44f9536bd6 | ||
|
|
1c1085a227 | ||
|
|
d7ef70469e | ||
|
|
1926127ddd | ||
|
|
8b509e56de | ||
|
|
1ecae8c87e | ||
|
|
659338f90c | ||
|
|
4df5b7bde7 | ||
|
|
acb2d0bd1b | ||
|
|
51aa369c80 | ||
|
|
017a00af46 | ||
|
|
6403ffe353 | ||
|
|
52650eed1d | ||
|
|
c40a98ba3c | ||
|
|
a31fc8b162 | ||
|
|
81c1524658 | ||
|
|
f2ead70f3d | ||
|
|
0f2d1a6553 | ||
|
|
87d817b83b | ||
|
|
7d4c020a9b | ||
|
|
acf932bf4f | ||
|
|
f562d9a277 | ||
|
|
3c92a96504 | ||
|
|
8b8e1df739 | ||
|
|
e596ea87cb | ||
|
|
602a0a4fb1 | ||
|
|
8d7d531ae0 | ||
|
|
43153a12e0 | ||
|
|
587e11c60a | ||
|
|
57da545e02 | ||
|
|
81f8290f01 | ||
|
|
626980bf27 | ||
|
|
6467f6734f | ||
|
|
5a30d11416 | ||
|
|
1f4105e8f9 | ||
|
|
caf9ff34e6 | ||
|
|
e42b27af3c | ||
|
|
34face15d2 | ||
|
|
e8fc8ee623 | ||
|
|
1a16e203b8 | ||
|
|
7d32c83f95 | ||
|
|
5dae303ce0 | ||
|
|
6e2a45b84e | ||
|
|
32f6532e9c | ||
|
|
6cbfbdd013 | ||
|
|
0c6fa60436 | ||
|
|
b04e916c23 | ||
|
|
1a32ba7d9a | ||
|
|
deccc26f1f | ||
|
|
9e38bd5b78 | ||
|
|
a329831b0b | ||
|
|
98dd1a9480 | ||
|
|
9c7c598c7d | ||
|
|
728c40def5 | ||
|
|
0bbe8a184d | ||
|
|
7592deed63 | ||
|
|
b9c759ce4f | ||
|
|
cd64562e1b | ||
|
|
8fddc9d71f | ||
|
|
5efb80d47b | ||
|
|
b49d8e2cba | ||
|
|
452544530d | ||
|
|
32ee7e6cf8 | ||
|
|
670663c406 | ||
|
|
0dbe4cf51e |
@@ -5,42 +5,13 @@
|
|||||||
!docs/
|
!docs/
|
||||||
|
|
||||||
# Platform - Libs
|
# Platform - Libs
|
||||||
!autogpt_platform/autogpt_libs/autogpt_libs/
|
!autogpt_platform/autogpt_libs/
|
||||||
!autogpt_platform/autogpt_libs/pyproject.toml
|
|
||||||
!autogpt_platform/autogpt_libs/poetry.lock
|
|
||||||
!autogpt_platform/autogpt_libs/README.md
|
|
||||||
|
|
||||||
# Platform - Backend
|
# Platform - Backend
|
||||||
!autogpt_platform/backend/backend/
|
!autogpt_platform/backend/
|
||||||
!autogpt_platform/backend/test/e2e_test_data.py
|
|
||||||
!autogpt_platform/backend/migrations/
|
|
||||||
!autogpt_platform/backend/schema.prisma
|
|
||||||
!autogpt_platform/backend/pyproject.toml
|
|
||||||
!autogpt_platform/backend/poetry.lock
|
|
||||||
!autogpt_platform/backend/README.md
|
|
||||||
!autogpt_platform/backend/.env
|
|
||||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
|
||||||
|
|
||||||
# Platform - Market
|
|
||||||
!autogpt_platform/market/market/
|
|
||||||
!autogpt_platform/market/scripts.py
|
|
||||||
!autogpt_platform/market/schema.prisma
|
|
||||||
!autogpt_platform/market/pyproject.toml
|
|
||||||
!autogpt_platform/market/poetry.lock
|
|
||||||
!autogpt_platform/market/README.md
|
|
||||||
|
|
||||||
# Platform - Frontend
|
# Platform - Frontend
|
||||||
!autogpt_platform/frontend/src/
|
!autogpt_platform/frontend/
|
||||||
!autogpt_platform/frontend/public/
|
|
||||||
!autogpt_platform/frontend/scripts/
|
|
||||||
!autogpt_platform/frontend/package.json
|
|
||||||
!autogpt_platform/frontend/pnpm-lock.yaml
|
|
||||||
!autogpt_platform/frontend/tsconfig.json
|
|
||||||
!autogpt_platform/frontend/README.md
|
|
||||||
## config
|
|
||||||
!autogpt_platform/frontend/*.config.*
|
|
||||||
!autogpt_platform/frontend/.env.*
|
|
||||||
!autogpt_platform/frontend/.env
|
|
||||||
|
|
||||||
# Classic - AutoGPT
|
# Classic - AutoGPT
|
||||||
!classic/original_autogpt/autogpt/
|
!classic/original_autogpt/autogpt/
|
||||||
@@ -64,6 +35,38 @@
|
|||||||
# Classic - Frontend
|
# Classic - Frontend
|
||||||
!classic/frontend/build/web/
|
!classic/frontend/build/web/
|
||||||
|
|
||||||
# Explicitly re-ignore some folders
|
# Explicitly re-ignore unwanted files from whitelisted directories
|
||||||
.*
|
# Note: These patterns MUST come after the whitelist rules to take effect
|
||||||
**/__pycache__
|
|
||||||
|
# Hidden files and directories (but keep frontend .env files needed for build)
|
||||||
|
**/.*
|
||||||
|
!autogpt_platform/frontend/.env
|
||||||
|
!autogpt_platform/frontend/.env.default
|
||||||
|
!autogpt_platform/frontend/.env.production
|
||||||
|
|
||||||
|
# Python artifacts
|
||||||
|
**/__pycache__/
|
||||||
|
**/*.pyc
|
||||||
|
**/*.pyo
|
||||||
|
**/.venv/
|
||||||
|
**/.ruff_cache/
|
||||||
|
**/.pytest_cache/
|
||||||
|
**/.coverage
|
||||||
|
**/htmlcov/
|
||||||
|
|
||||||
|
# Node artifacts
|
||||||
|
**/node_modules/
|
||||||
|
**/.next/
|
||||||
|
**/storybook-static/
|
||||||
|
**/playwright-report/
|
||||||
|
**/test-results/
|
||||||
|
|
||||||
|
# Build artifacts
|
||||||
|
**/dist/
|
||||||
|
**/build/
|
||||||
|
!autogpt_platform/frontend/src/**/build/
|
||||||
|
**/target/
|
||||||
|
|
||||||
|
# Logs and temp files
|
||||||
|
**/*.log
|
||||||
|
**/*.tmp
|
||||||
|
|||||||
2
.github/workflows/classic-frontend-ci.yml
vendored
2
.github/workflows/classic-frontend-ci.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||||
if: github.event_name == 'push'
|
if: github.event_name == 'push'
|
||||||
uses: peter-evans/create-pull-request@v7
|
uses: peter-evans/create-pull-request@v8
|
||||||
with:
|
with:
|
||||||
add-paths: classic/frontend/build/web
|
add-paths: classic/frontend/build/web
|
||||||
base: ${{ github.ref_name }}
|
base: ${{ github.ref_name }}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.event.workflow_run.head_branch }}
|
ref: ${{ github.event.workflow_run.head_branch }}
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
@@ -42,7 +42,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Get CI failure details
|
- name: Get CI failure details
|
||||||
id: failure_details
|
id: failure_details
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const run = await github.rest.actions.getWorkflowRun({
|
const run = await github.rest.actions.getWorkflowRun({
|
||||||
|
|||||||
11
.github/workflows/claude-dependabot.yml
vendored
11
.github/workflows/claude-dependabot.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
actions: read # Required for CI access
|
actions: read # Required for CI access
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
@@ -41,7 +41,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -78,7 +78,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -124,7 +124,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
@@ -309,6 +309,7 @@ jobs:
|
|||||||
uses: anthropics/claude-code-action@v1
|
uses: anthropics/claude-code-action@v1
|
||||||
with:
|
with:
|
||||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||||
|
allowed_bots: "dependabot[bot]"
|
||||||
claude_args: |
|
claude_args: |
|
||||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||||
prompt: |
|
prompt: |
|
||||||
|
|||||||
10
.github/workflows/claude.yml
vendored
10
.github/workflows/claude.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
|||||||
actions: read # Required for CI access
|
actions: read # Required for CI access
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -94,7 +94,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -140,7 +140,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
|
|||||||
2
.github/workflows/codeql.yml
vendored
2
.github/workflows/codeql.yml
vendored
@@ -58,7 +58,7 @@ jobs:
|
|||||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
# Initializes the CodeQL tools for scanning.
|
# Initializes the CodeQL tools for scanning.
|
||||||
- name: Initialize CodeQL
|
- name: Initialize CodeQL
|
||||||
|
|||||||
10
.github/workflows/copilot-setup-steps.yml
vendored
10
.github/workflows/copilot-setup-steps.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
# If you do not check out your code, Copilot will do this for you.
|
# If you do not check out your code, Copilot will do this for you.
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
submodules: true
|
submodules: true
|
||||||
@@ -39,7 +39,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -76,7 +76,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -132,7 +132,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
|
|||||||
4
.github/workflows/docs-block-sync.yml
vendored
4
.github/workflows/docs-block-sync.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
4
.github/workflows/docs-claude-review.yml
vendored
4
.github/workflows/docs-claude-review.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
4
.github/workflows/docs-enhance.yml
vendored
4
.github/workflows/docs-enhance.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
|
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger deploy workflow
|
- name: Trigger deploy workflow
|
||||||
uses: peter-evans/repository-dispatch@v3
|
uses: peter-evans/repository-dispatch@v4
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.ref_name || 'master' }}
|
ref: ${{ github.ref_name || 'master' }}
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger deploy workflow
|
- name: Trigger deploy workflow
|
||||||
uses: peter-evans/repository-dispatch@v3
|
uses: peter-evans/repository-dispatch@v4
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||||
|
|||||||
4
.github/workflows/platform-backend-ci.yml
vendored
4
.github/workflows/platform-backend-ci.yml
vendored
@@ -68,7 +68,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
submodules: true
|
submodules: true
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ jobs:
|
|||||||
- name: Check comment permissions and deployment status
|
- name: Check comment permissions and deployment status
|
||||||
id: check_status
|
id: check_status
|
||||||
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const commentBody = context.payload.comment.body.trim();
|
const commentBody = context.payload.comment.body.trim();
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post permission denied comment
|
- name: Post permission denied comment
|
||||||
if: steps.check_status.outputs.permission_denied == 'true'
|
if: steps.check_status.outputs.permission_denied == 'true'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -68,7 +68,7 @@ jobs:
|
|||||||
- name: Get PR details for deployment
|
- name: Get PR details for deployment
|
||||||
id: pr_details
|
id: pr_details
|
||||||
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const pr = await github.rest.pulls.get({
|
const pr = await github.rest.pulls.get({
|
||||||
@@ -82,7 +82,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Dispatch Deploy Event
|
- name: Dispatch Deploy Event
|
||||||
if: steps.check_status.outputs.should_deploy == 'true'
|
if: steps.check_status.outputs.should_deploy == 'true'
|
||||||
uses: peter-evans/repository-dispatch@v3
|
uses: peter-evans/repository-dispatch@v4
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||||
@@ -98,7 +98,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post deploy success comment
|
- name: Post deploy success comment
|
||||||
if: steps.check_status.outputs.should_deploy == 'true'
|
if: steps.check_status.outputs.should_deploy == 'true'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -110,7 +110,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Dispatch Undeploy Event (from comment)
|
- name: Dispatch Undeploy Event (from comment)
|
||||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||||
uses: peter-evans/repository-dispatch@v3
|
uses: peter-evans/repository-dispatch@v4
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||||
@@ -126,7 +126,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post undeploy success comment
|
- name: Post undeploy success comment
|
||||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -139,7 +139,7 @@ jobs:
|
|||||||
- name: Check deployment status on PR close
|
- name: Check deployment status on PR close
|
||||||
id: check_pr_close
|
id: check_pr_close
|
||||||
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const comments = await github.rest.issues.listComments({
|
const comments = await github.rest.issues.listComments({
|
||||||
@@ -168,7 +168,7 @@ jobs:
|
|||||||
github.event_name == 'pull_request' &&
|
github.event_name == 'pull_request' &&
|
||||||
github.event.action == 'closed' &&
|
github.event.action == 'closed' &&
|
||||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||||
uses: peter-evans/repository-dispatch@v3
|
uses: peter-evans/repository-dispatch@v4
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||||
@@ -187,7 +187,7 @@ jobs:
|
|||||||
github.event_name == 'pull_request' &&
|
github.event_name == 'pull_request' &&
|
||||||
github.event.action == 'closed' &&
|
github.event.action == 'closed' &&
|
||||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
|
|||||||
257
.github/workflows/platform-frontend-ci.yml
vendored
257
.github/workflows/platform-frontend-ci.yml
vendored
@@ -26,12 +26,11 @@ jobs:
|
|||||||
setup:
|
setup:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
outputs:
|
||||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
|
||||||
components-changed: ${{ steps.filter.outputs.components }}
|
components-changed: ${{ steps.filter.outputs.components }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Check for component changes
|
- name: Check for component changes
|
||||||
uses: dorny/paths-filter@v3
|
uses: dorny/paths-filter@v3
|
||||||
@@ -41,28 +40,17 @@ jobs:
|
|||||||
components:
|
components:
|
||||||
- 'autogpt_platform/frontend/src/components/**'
|
- 'autogpt_platform/frontend/src/components/**'
|
||||||
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Enable corepack
|
||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Generate cache key
|
- name: Set up Node
|
||||||
id: cache-key
|
uses: actions/setup-node@v6
|
||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
- name: Cache dependencies
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies to populate cache
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
@@ -71,24 +59,17 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Enable corepack
|
||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Set up Node
|
||||||
uses: actions/cache@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
@@ -107,26 +88,19 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Enable corepack
|
||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Set up Node
|
||||||
uses: actions/cache@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
@@ -141,30 +115,20 @@ jobs:
|
|||||||
exitOnceUploaded: true
|
exitOnceUploaded: true
|
||||||
|
|
||||||
e2e_test:
|
e2e_test:
|
||||||
|
name: end-to-end tests
|
||||||
runs-on: big-boi
|
runs-on: big-boi
|
||||||
needs: setup
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Platform - Copy default supabase .env
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Copy default supabase .env
|
|
||||||
run: |
|
run: |
|
||||||
cp ../.env.default ../.env
|
cp ../.env.default ../.env
|
||||||
|
|
||||||
- name: Copy backend .env and set OpenAI API key
|
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||||
run: |
|
run: |
|
||||||
cp ../backend/.env.default ../backend/.env
|
cp ../backend/.env.default ../backend/.env
|
||||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||||
@@ -172,77 +136,125 @@ jobs:
|
|||||||
# Used by E2E test data script to generate embeddings for approved store agents
|
# Used by E2E test data script to generate embeddings for approved store agents
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Platform - Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
- name: Cache Docker layers
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
with:
|
||||||
path: /tmp/.buildx-cache
|
driver: docker-container
|
||||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
driver-opts: network=host
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-buildx-frontend-test-
|
|
||||||
|
|
||||||
- name: Run docker compose
|
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||||
|
uses: crazy-max/ghaction-github-runtime@v3
|
||||||
|
|
||||||
|
- name: Set up Platform - Build Docker images (with cache)
|
||||||
|
working-directory: autogpt_platform
|
||||||
run: |
|
run: |
|
||||||
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
|
pip install pyyaml
|
||||||
|
|
||||||
|
# Resolve extends and generate a flat compose file that bake can understand
|
||||||
|
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||||
|
|
||||||
|
# Add cache configuration to the resolved compose file
|
||||||
|
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||||
|
--source docker-compose.resolved.yml \
|
||||||
|
--cache-from "type=gha" \
|
||||||
|
--cache-to "type=gha,mode=max" \
|
||||||
|
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
||||||
|
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
||||||
|
--git-ref "${{ github.ref }}"
|
||||||
|
|
||||||
|
# Build with bake using the resolved compose file (now includes cache config)
|
||||||
|
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||||
env:
|
env:
|
||||||
DOCKER_BUILDKIT: 1
|
NEXT_PUBLIC_PW_TEST: true
|
||||||
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
|
||||||
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
|
||||||
|
|
||||||
- name: Move cache
|
- name: Set up tests - Cache E2E test data
|
||||||
run: |
|
id: e2e-data-cache
|
||||||
rm -rf /tmp/.buildx-cache
|
uses: actions/cache@v5
|
||||||
if [ -d "/tmp/.buildx-cache-new" ]; then
|
with:
|
||||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
path: /tmp/e2e_test_data.sql
|
||||||
fi
|
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
||||||
|
|
||||||
- name: Wait for services to be ready
|
- name: Set up Platform - Start Supabase DB + Auth
|
||||||
run: |
|
run: |
|
||||||
|
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||||
|
echo "Waiting for database to be ready..."
|
||||||
|
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||||
|
echo "Waiting for auth service to be ready..."
|
||||||
|
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||||
|
|
||||||
|
- name: Set up Platform - Run migrations
|
||||||
|
run: |
|
||||||
|
echo "Running migrations..."
|
||||||
|
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||||
|
echo "✅ Migrations completed"
|
||||||
|
env:
|
||||||
|
NEXT_PUBLIC_PW_TEST: true
|
||||||
|
|
||||||
|
- name: Set up tests - Load cached E2E test data
|
||||||
|
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||||
|
run: |
|
||||||
|
echo "✅ Found cached E2E test data, restoring..."
|
||||||
|
{
|
||||||
|
echo "SET session_replication_role = 'replica';"
|
||||||
|
cat /tmp/e2e_test_data.sql
|
||||||
|
echo "SET session_replication_role = 'origin';"
|
||||||
|
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||||
|
# Refresh materialized views after restore
|
||||||
|
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||||
|
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||||
|
|
||||||
|
echo "✅ E2E test data restored from cache"
|
||||||
|
|
||||||
|
- name: Set up Platform - Start (all other services)
|
||||||
|
run: |
|
||||||
|
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||||
echo "Waiting for rest_server to be ready..."
|
echo "Waiting for rest_server to be ready..."
|
||||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||||
echo "Waiting for database to be ready..."
|
env:
|
||||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
NEXT_PUBLIC_PW_TEST: true
|
||||||
|
|
||||||
- name: Create E2E test data
|
- name: Set up tests - Create E2E test data
|
||||||
|
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||||
run: |
|
run: |
|
||||||
echo "Creating E2E test data..."
|
echo "Creating E2E test data..."
|
||||||
# First try to run the script from inside the container
|
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||||
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
|
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||||
echo "✅ Found e2e_test_data.py in container, running it..."
|
echo "❌ E2E test data creation failed!"
|
||||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
|
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||||
echo "❌ E2E test data creation failed!"
|
exit 1
|
||||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
}
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
else
|
|
||||||
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
|
|
||||||
# Copy the script into the container and run it
|
|
||||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
|
|
||||||
echo "❌ Failed to copy script to container"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
|
||||||
echo "❌ E2E test data creation failed!"
|
|
||||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||||
uses: actions/cache@v4
|
echo "Dumping database for cache..."
|
||||||
|
{
|
||||||
|
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||||
|
pg_dump -U postgres --data-only --column-inserts \
|
||||||
|
--table='auth.users' postgres
|
||||||
|
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||||
|
pg_dump -U postgres --data-only --column-inserts \
|
||||||
|
--schema=platform \
|
||||||
|
--exclude-table='platform._prisma_migrations' \
|
||||||
|
--exclude-table='platform.apscheduler_jobs' \
|
||||||
|
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||||
|
postgres
|
||||||
|
} > /tmp/e2e_test_data.sql
|
||||||
|
|
||||||
|
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||||
|
|
||||||
|
- name: Set up tests - Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Set up tests - Set up Node
|
||||||
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Set up tests - Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
- name: Install Browser 'chromium'
|
- name: Set up tests - Install browser 'chromium'
|
||||||
run: pnpm playwright install --with-deps chromium
|
run: pnpm playwright install --with-deps chromium
|
||||||
|
|
||||||
- name: Run Playwright tests
|
- name: Run Playwright tests
|
||||||
@@ -269,7 +281,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Print Final Docker Compose logs
|
- name: Print Final Docker Compose logs
|
||||||
if: always()
|
if: always()
|
||||||
run: docker compose -f ../docker-compose.yml logs
|
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||||
|
|
||||||
integration_test:
|
integration_test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -277,26 +289,19 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Enable corepack
|
||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Set up Node
|
||||||
uses: actions/cache@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|||||||
16
.github/workflows/platform-fullstack-ci.yml
vendored
16
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -29,10 +29,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ jobs:
|
|||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
key: ${{ steps.cache-key.outputs.key }}
|
||||||
@@ -56,19 +56,19 @@ jobs:
|
|||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
types:
|
types:
|
||||||
runs-on: ubuntu-latest
|
runs-on: big-boi
|
||||||
needs: setup
|
needs: setup
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -85,10 +85,10 @@ jobs:
|
|||||||
|
|
||||||
- name: Run docker compose
|
- name: Run docker compose
|
||||||
run: |
|
run: |
|
||||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
|||||||
2
.github/workflows/repo-workflow-checker.yml
vendored
2
.github/workflows/repo-workflow-checker.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
# - name: Wait some time for all actions to start
|
# - name: Wait some time for all actions to start
|
||||||
# run: sleep 30
|
# run: sleep 30
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
# with:
|
# with:
|
||||||
# fetch-depth: 0
|
# fetch-depth: 0
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
|
|||||||
195
.github/workflows/scripts/docker-ci-fix-compose-build-cache.py
vendored
Normal file
195
.github/workflows/scripts/docker-ci-fix-compose-build-cache.py
vendored
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Add cache configuration to a resolved docker-compose file for all services
|
||||||
|
that have a build key, and ensure image names match what docker compose expects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_BRANCH = "dev"
|
||||||
|
CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Add cache config to a resolved compose file"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--source",
|
||||||
|
required=True,
|
||||||
|
help="Source compose file to read (should be output of `docker compose config`)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache-from",
|
||||||
|
default="type=gha",
|
||||||
|
help="Cache source configuration",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache-to",
|
||||||
|
default="type=gha,mode=max",
|
||||||
|
help="Cache destination configuration",
|
||||||
|
)
|
||||||
|
for component in CACHE_BUILDS_FOR_COMPONENTS:
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{component}-hash",
|
||||||
|
default="",
|
||||||
|
help=f"Hash for {component} cache scope (e.g., from hashFiles())",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--git-ref",
|
||||||
|
default="",
|
||||||
|
help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
|
||||||
|
git_ref_scope = ""
|
||||||
|
if args.git_ref:
|
||||||
|
git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")
|
||||||
|
|
||||||
|
with open(args.source, "r") as f:
|
||||||
|
compose = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# Get project name from compose file or default
|
||||||
|
project_name = compose.get("name", "autogpt_platform")
|
||||||
|
|
||||||
|
def get_image_name(dockerfile: str, target: str) -> str:
|
||||||
|
"""Generate image name based on Dockerfile folder and build target."""
|
||||||
|
dockerfile_parts = dockerfile.replace("\\", "/").split("/")
|
||||||
|
if len(dockerfile_parts) >= 2:
|
||||||
|
folder_name = dockerfile_parts[-2] # e.g., "backend" or "frontend"
|
||||||
|
else:
|
||||||
|
folder_name = "app"
|
||||||
|
return f"{project_name}-{folder_name}:{target}"
|
||||||
|
|
||||||
|
def get_build_key(dockerfile: str, target: str) -> str:
|
||||||
|
"""Generate a unique key for a Dockerfile+target combination."""
|
||||||
|
return f"{dockerfile}:{target}"
|
||||||
|
|
||||||
|
def get_component(dockerfile: str) -> str | None:
|
||||||
|
"""Get component name (frontend/backend) from dockerfile path."""
|
||||||
|
for component in CACHE_BUILDS_FOR_COMPONENTS:
|
||||||
|
if component in dockerfile:
|
||||||
|
return component
|
||||||
|
return None
|
||||||
|
|
||||||
|
# First pass: collect all services with build configs and identify duplicates
|
||||||
|
# Track which (dockerfile, target) combinations we've seen
|
||||||
|
build_key_to_first_service: dict[str, str] = {}
|
||||||
|
services_to_build: list[str] = []
|
||||||
|
services_to_dedupe: list[str] = []
|
||||||
|
|
||||||
|
for service_name, service_config in compose.get("services", {}).items():
|
||||||
|
if "build" not in service_config:
|
||||||
|
continue
|
||||||
|
|
||||||
|
build_config = service_config["build"]
|
||||||
|
dockerfile = build_config.get("dockerfile", "Dockerfile")
|
||||||
|
target = build_config.get("target", "default")
|
||||||
|
build_key = get_build_key(dockerfile, target)
|
||||||
|
|
||||||
|
if build_key not in build_key_to_first_service:
|
||||||
|
# First service with this build config - it will do the actual build
|
||||||
|
build_key_to_first_service[build_key] = service_name
|
||||||
|
services_to_build.append(service_name)
|
||||||
|
else:
|
||||||
|
# Duplicate - will just use the image from the first service
|
||||||
|
services_to_dedupe.append(service_name)
|
||||||
|
|
||||||
|
# Second pass: configure builds and deduplicate
|
||||||
|
modified_services = []
|
||||||
|
for service_name, service_config in compose.get("services", {}).items():
|
||||||
|
if "build" not in service_config:
|
||||||
|
continue
|
||||||
|
|
||||||
|
build_config = service_config["build"]
|
||||||
|
dockerfile = build_config.get("dockerfile", "Dockerfile")
|
||||||
|
target = build_config.get("target", "latest")
|
||||||
|
image_name = get_image_name(dockerfile, target)
|
||||||
|
|
||||||
|
# Set image name for all services (needed for both builders and deduped)
|
||||||
|
service_config["image"] = image_name
|
||||||
|
|
||||||
|
if service_name in services_to_dedupe:
|
||||||
|
# Remove build config - this service will use the pre-built image
|
||||||
|
del service_config["build"]
|
||||||
|
continue
|
||||||
|
|
||||||
|
# This service will do the actual build - add cache config
|
||||||
|
cache_from_list = []
|
||||||
|
cache_to_list = []
|
||||||
|
|
||||||
|
component = get_component(dockerfile)
|
||||||
|
if not component:
|
||||||
|
# Skip services that don't clearly match frontend/backend
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get the hash for this component
|
||||||
|
component_hash = getattr(args, f"{component}_hash")
|
||||||
|
|
||||||
|
# Scope format: platform-{component}-{target}-{hash|ref}
|
||||||
|
# Example: platform-backend-server-abc123
|
||||||
|
|
||||||
|
if "type=gha" in args.cache_from:
|
||||||
|
# 1. Primary: exact hash match (most specific)
|
||||||
|
if component_hash:
|
||||||
|
hash_scope = f"platform-{component}-{target}-{component_hash}"
|
||||||
|
cache_from_list.append(f"{args.cache_from},scope={hash_scope}")
|
||||||
|
|
||||||
|
# 2. Fallback: branch-based cache
|
||||||
|
if git_ref_scope:
|
||||||
|
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
|
||||||
|
cache_from_list.append(f"{args.cache_from},scope={ref_scope}")
|
||||||
|
|
||||||
|
# 3. Fallback: dev branch cache (for PRs/feature branches)
|
||||||
|
if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
|
||||||
|
master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
|
||||||
|
cache_from_list.append(f"{args.cache_from},scope={master_scope}")
|
||||||
|
|
||||||
|
if "type=gha" in args.cache_to:
|
||||||
|
# Write to both hash-based and branch-based scopes
|
||||||
|
if component_hash:
|
||||||
|
hash_scope = f"platform-{component}-{target}-{component_hash}"
|
||||||
|
cache_to_list.append(f"{args.cache_to},scope={hash_scope}")
|
||||||
|
|
||||||
|
if git_ref_scope:
|
||||||
|
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
|
||||||
|
cache_to_list.append(f"{args.cache_to},scope={ref_scope}")
|
||||||
|
|
||||||
|
# Ensure we have at least one cache source/target
|
||||||
|
if not cache_from_list:
|
||||||
|
cache_from_list.append(args.cache_from)
|
||||||
|
if not cache_to_list:
|
||||||
|
cache_to_list.append(args.cache_to)
|
||||||
|
|
||||||
|
build_config["cache_from"] = cache_from_list
|
||||||
|
build_config["cache_to"] = cache_to_list
|
||||||
|
modified_services.append(service_name)
|
||||||
|
|
||||||
|
# Write back to the same file
|
||||||
|
with open(args.source, "w") as f:
|
||||||
|
yaml.dump(compose, f, default_flow_style=False, sort_keys=False)
|
||||||
|
|
||||||
|
print(f"Added cache config to {len(modified_services)} services in {args.source}:")
|
||||||
|
for svc in modified_services:
|
||||||
|
svc_config = compose["services"][svc]
|
||||||
|
build_cfg = svc_config.get("build", {})
|
||||||
|
cache_from_list = build_cfg.get("cache_from", ["none"])
|
||||||
|
cache_to_list = build_cfg.get("cache_to", ["none"])
|
||||||
|
print(f" - {svc}")
|
||||||
|
print(f" image: {svc_config.get('image', 'N/A')}")
|
||||||
|
print(f" cache_from: {cache_from_list}")
|
||||||
|
print(f" cache_to: {cache_to_list}")
|
||||||
|
if services_to_dedupe:
|
||||||
|
print(
|
||||||
|
f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
|
||||||
|
)
|
||||||
|
for svc in services_to_dedupe:
|
||||||
|
print(f" - {svc} -> {compose['services'][svc].get('image', 'N/A')}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -45,6 +45,11 @@ AutoGPT Platform is a monorepo containing:
|
|||||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||||
|
|
||||||
|
### Branching Strategy
|
||||||
|
|
||||||
|
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||||
|
- **`master`** is the production branch. Only used for production releases.
|
||||||
|
|
||||||
### Creating Pull Requests
|
### Creating Pull Requests
|
||||||
|
|
||||||
- Create the PR against the `dev` branch of the repository.
|
- Create the PR against the `dev` branch of the repository.
|
||||||
|
|||||||
1865
autogpt_platform/autogpt_libs/poetry.lock
generated
1865
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
|
|||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.10,<4.0"
|
python = ">=3.10,<4.0"
|
||||||
colorama = "^0.4.6"
|
colorama = "^0.4.6"
|
||||||
cryptography = "^45.0"
|
cryptography = "^46.0"
|
||||||
expiringdict = "^1.2.2"
|
expiringdict = "^1.2.2"
|
||||||
fastapi = "^0.116.1"
|
fastapi = "^0.128.7"
|
||||||
google-cloud-logging = "^3.12.1"
|
google-cloud-logging = "^3.13.0"
|
||||||
launchdarkly-server-sdk = "^9.12.0"
|
launchdarkly-server-sdk = "^9.15.0"
|
||||||
pydantic = "^2.11.7"
|
pydantic = "^2.12.5"
|
||||||
pydantic-settings = "^2.10.1"
|
pydantic-settings = "^2.12.0"
|
||||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
supabase = "^2.16.0"
|
supabase = "^2.28.0"
|
||||||
uvicorn = "^0.35.0"
|
uvicorn = "^0.40.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pyright = "^1.1.404"
|
pyright = "^1.1.408"
|
||||||
pytest = "^8.4.1"
|
pytest = "^8.4.1"
|
||||||
pytest-asyncio = "^1.1.0"
|
pytest-asyncio = "^1.3.0"
|
||||||
pytest-mock = "^3.14.1"
|
pytest-mock = "^3.15.1"
|
||||||
pytest-cov = "^6.2.1"
|
pytest-cov = "^7.0.0"
|
||||||
ruff = "^0.12.11"
|
ruff = "^0.15.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
# ============================ DEPENDENCY BUILDER ============================ #
|
||||||
|
|
||||||
FROM debian:13-slim AS builder
|
FROM debian:13-slim AS builder
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
@@ -51,7 +53,9 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
|
|||||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||||
|
|
||||||
FROM debian:13-slim AS server_dependencies
|
# ============================== BACKEND SERVER ============================== #
|
||||||
|
|
||||||
|
FROM debian:13-slim AS server
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
@@ -62,16 +66,21 @@ ENV POETRY_HOME=/opt/poetry \
|
|||||||
DEBIAN_FRONTEND=noninteractive
|
DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PATH=/opt/poetry/bin:$PATH
|
ENV PATH=/opt/poetry/bin:$PATH
|
||||||
|
|
||||||
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
|
||||||
RUN apt-get update && apt-get install -y \
|
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
|
||||||
|
# for the bash_exec MCP tool.
|
||||||
|
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
python3.13 \
|
python3.13 \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
imagemagick \
|
imagemagick \
|
||||||
|
jq \
|
||||||
|
ripgrep \
|
||||||
|
tree \
|
||||||
|
bubblewrap \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy only necessary files from builder
|
|
||||||
COPY --from=builder /app /app
|
|
||||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||||
# Copy Node.js installation for Prisma
|
# Copy Node.js installation for Prisma
|
||||||
@@ -81,30 +90,54 @@ COPY --from=builder /usr/bin/npm /usr/bin/npm
|
|||||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
||||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||||
|
|
||||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
|
||||||
|
|
||||||
RUN mkdir -p /app/autogpt_platform/autogpt_libs
|
|
||||||
RUN mkdir -p /app/autogpt_platform/backend
|
|
||||||
|
|
||||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
|
||||||
|
|
||||||
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
|
|
||||||
|
|
||||||
WORKDIR /app/autogpt_platform/backend
|
WORKDIR /app/autogpt_platform/backend
|
||||||
|
|
||||||
FROM server_dependencies AS migrate
|
# Copy only the .venv from builder (not the entire /app directory)
|
||||||
|
# The .venv includes the generated Prisma client
|
||||||
|
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
|
||||||
|
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||||
|
|
||||||
# Migration stage only needs schema and migrations - much lighter than full backend
|
# Copy dependency files + autogpt_libs (path dependency)
|
||||||
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
|
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||||
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
|
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
|
||||||
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
|
|
||||||
|
|
||||||
FROM server_dependencies AS server
|
# Copy backend code + docs (for Copilot docs search)
|
||||||
|
COPY autogpt_platform/backend ./
|
||||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
|
||||||
COPY docs /app/docs
|
COPY docs /app/docs
|
||||||
RUN poetry install --no-ansi --only-root
|
RUN poetry install --no-ansi --only-root
|
||||||
|
|
||||||
ENV PORT=8000
|
ENV PORT=8000
|
||||||
|
|
||||||
CMD ["poetry", "run", "rest"]
|
CMD ["poetry", "run", "rest"]
|
||||||
|
|
||||||
|
# =============================== DB MIGRATOR =============================== #
|
||||||
|
|
||||||
|
# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
|
||||||
|
FROM debian:13-slim AS migrate
|
||||||
|
|
||||||
|
WORKDIR /app/autogpt_platform/backend
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
python3.13 \
|
||||||
|
python3-pip \
|
||||||
|
ca-certificates \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy Node.js from builder (needed for Prisma CLI)
|
||||||
|
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||||
|
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||||
|
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||||
|
|
||||||
|
# Copy Prisma binaries
|
||||||
|
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||||
|
|
||||||
|
# Install prisma-client-py directly (much smaller than copying full venv)
|
||||||
|
RUN pip3 install prisma>=0.15.0 --break-system-packages
|
||||||
|
|
||||||
|
COPY autogpt_platform/backend/schema.prisma ./
|
||||||
|
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||||
|
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||||
|
COPY autogpt_platform/backend/migrations ./migrations
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing_extensions import TypedDict
|
|||||||
|
|
||||||
import backend.api.features.store.cache as store_cache
|
import backend.api.features.store.cache as store_cache
|
||||||
import backend.api.features.store.model as store_model
|
import backend.api.features.store.model as store_model
|
||||||
import backend.data.block
|
import backend.blocks
|
||||||
from backend.api.external.middleware import require_permission
|
from backend.api.external.middleware import require_permission
|
||||||
from backend.data import execution as execution_db
|
from backend.data import execution as execution_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
@@ -67,7 +67,7 @@ async def get_user_info(
|
|||||||
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
|
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
|
||||||
)
|
)
|
||||||
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
blocks = [block() for block in backend.blocks.get_blocks().values()]
|
||||||
return [b.to_dict() for b in blocks if not b.disabled]
|
return [b.to_dict() for b in blocks if not b.disabled]
|
||||||
|
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ async def execute_graph_block(
|
|||||||
require_permission(APIKeyPermission.EXECUTE_BLOCK)
|
require_permission(APIKeyPermission.EXECUTE_BLOCK)
|
||||||
),
|
),
|
||||||
) -> CompletedBlockOutput:
|
) -> CompletedBlockOutput:
|
||||||
obj = backend.data.block.get_block(block_id)
|
obj = backend.blocks.get_block(block_id)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||||
if obj.disabled:
|
if obj.disabled:
|
||||||
|
|||||||
@@ -10,10 +10,15 @@ import backend.api.features.library.db as library_db
|
|||||||
import backend.api.features.library.model as library_model
|
import backend.api.features.library.model as library_model
|
||||||
import backend.api.features.store.db as store_db
|
import backend.api.features.store.db as store_db
|
||||||
import backend.api.features.store.model as store_model
|
import backend.api.features.store.model as store_model
|
||||||
import backend.data.block
|
|
||||||
from backend.blocks import load_all_blocks
|
from backend.blocks import load_all_blocks
|
||||||
|
from backend.blocks._base import (
|
||||||
|
AnyBlockSchema,
|
||||||
|
BlockCategory,
|
||||||
|
BlockInfo,
|
||||||
|
BlockSchema,
|
||||||
|
BlockType,
|
||||||
|
)
|
||||||
from backend.blocks.llm import LlmModel
|
from backend.blocks.llm import LlmModel
|
||||||
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
|
||||||
from backend.data.db import query_raw_with_schema
|
from backend.data.db import query_raw_with_schema
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.cache import cached
|
from backend.util.cache import cached
|
||||||
@@ -22,7 +27,7 @@ from backend.util.models import Pagination
|
|||||||
from .model import (
|
from .model import (
|
||||||
BlockCategoryResponse,
|
BlockCategoryResponse,
|
||||||
BlockResponse,
|
BlockResponse,
|
||||||
BlockType,
|
BlockTypeFilter,
|
||||||
CountResponse,
|
CountResponse,
|
||||||
FilterType,
|
FilterType,
|
||||||
Provider,
|
Provider,
|
||||||
@@ -88,7 +93,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
|
|||||||
def get_blocks(
|
def get_blocks(
|
||||||
*,
|
*,
|
||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
type: BlockType | None = None,
|
type: BlockTypeFilter | None = None,
|
||||||
provider: ProviderName | None = None,
|
provider: ProviderName | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 50,
|
page_size: int = 50,
|
||||||
@@ -669,9 +674,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
|||||||
for block_type in load_all_blocks().values():
|
for block_type in load_all_blocks().values():
|
||||||
block: AnyBlockSchema = block_type()
|
block: AnyBlockSchema = block_type()
|
||||||
if block.disabled or block.block_type in (
|
if block.disabled or block.block_type in (
|
||||||
backend.data.block.BlockType.INPUT,
|
BlockType.INPUT,
|
||||||
backend.data.block.BlockType.OUTPUT,
|
BlockType.OUTPUT,
|
||||||
backend.data.block.BlockType.AGENT,
|
BlockType.AGENT,
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
# Find the execution count for this block
|
# Find the execution count for this block
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
import backend.api.features.library.model as library_model
|
import backend.api.features.library.model as library_model
|
||||||
import backend.api.features.store.model as store_model
|
import backend.api.features.store.model as store_model
|
||||||
from backend.data.block import BlockInfo
|
from backend.blocks._base import BlockInfo
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
@@ -15,7 +15,7 @@ FilterType = Literal[
|
|||||||
"my_agents",
|
"my_agents",
|
||||||
]
|
]
|
||||||
|
|
||||||
BlockType = Literal["all", "input", "action", "output"]
|
BlockTypeFilter = Literal["all", "input", "action", "output"]
|
||||||
|
|
||||||
|
|
||||||
class SearchEntry(BaseModel):
|
class SearchEntry(BaseModel):
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ async def get_block_categories(
|
|||||||
)
|
)
|
||||||
async def get_blocks(
|
async def get_blocks(
|
||||||
category: Annotated[str | None, fastapi.Query()] = None,
|
category: Annotated[str | None, fastapi.Query()] = None,
|
||||||
type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
|
type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
|
||||||
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
|
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
|
||||||
page: Annotated[int, fastapi.Query()] = 1,
|
page: Annotated[int, fastapi.Query()] = 1,
|
||||||
page_size: Annotated[int, fastapi.Query()] = 50,
|
page_size: Annotated[int, fastapi.Query()] = 50,
|
||||||
|
|||||||
@@ -27,12 +27,11 @@ class ChatConfig(BaseSettings):
|
|||||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||||
|
|
||||||
# Streaming Configuration
|
# Streaming Configuration
|
||||||
max_context_messages: int = Field(
|
|
||||||
default=50, ge=1, le=200, description="Maximum context messages"
|
|
||||||
)
|
|
||||||
|
|
||||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
max_retries: int = Field(
|
||||||
|
default=3,
|
||||||
|
description="Max retries for fallback path (SDK handles retries internally)",
|
||||||
|
)
|
||||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||||
max_agent_schedules: int = Field(
|
max_agent_schedules: int = Field(
|
||||||
default=30, description="Maximum number of agent schedules"
|
default=30, description="Maximum number of agent schedules"
|
||||||
@@ -93,6 +92,37 @@ class ChatConfig(BaseSettings):
|
|||||||
description="Name of the prompt in Langfuse to fetch",
|
description="Name of the prompt in Langfuse to fetch",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Claude Agent SDK Configuration
|
||||||
|
use_claude_agent_sdk: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Use Claude Agent SDK for chat completions",
|
||||||
|
)
|
||||||
|
claude_agent_model: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Model for the Claude Agent SDK path. If None, derives from "
|
||||||
|
"the `model` field by stripping the OpenRouter provider prefix.",
|
||||||
|
)
|
||||||
|
claude_agent_max_buffer_size: int = Field(
|
||||||
|
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
|
||||||
|
description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
|
||||||
|
"Increase if tool outputs exceed the limit.",
|
||||||
|
)
|
||||||
|
claude_agent_max_subtasks: int = Field(
|
||||||
|
default=10,
|
||||||
|
description="Max number of sub-agent Tasks the SDK can spawn per session.",
|
||||||
|
)
|
||||||
|
claude_agent_use_resume: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Use --resume for multi-turn conversations instead of "
|
||||||
|
"history compression. Falls back to compression when unavailable.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extended thinking configuration for Claude models
|
||||||
|
thinking_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable adaptive thinking for Claude models via OpenRouter",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("api_key", mode="before")
|
@field_validator("api_key", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_api_key(cls, v):
|
def get_api_key(cls, v):
|
||||||
@@ -132,6 +162,17 @@ class ChatConfig(BaseSettings):
|
|||||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@field_validator("use_claude_agent_sdk", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def get_use_claude_agent_sdk(cls, v):
|
||||||
|
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||||
|
# Check environment variable - default to True if not set
|
||||||
|
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||||
|
if env_val:
|
||||||
|
return env_val in ("true", "1", "yes", "on")
|
||||||
|
# Default to True (SDK enabled by default)
|
||||||
|
return True if v is None else v
|
||||||
|
|
||||||
# Prompt paths for different contexts
|
# Prompt paths for different contexts
|
||||||
PROMPT_PATHS: dict[str, str] = {
|
PROMPT_PATHS: dict[str, str] = {
|
||||||
"default": "prompts/chat_system.md",
|
"default": "prompts/chat_system.md",
|
||||||
|
|||||||
@@ -45,10 +45,7 @@ async def create_chat_session(
|
|||||||
successfulAgentRuns=SafeJson({}),
|
successfulAgentRuns=SafeJson({}),
|
||||||
successfulAgentSchedules=SafeJson({}),
|
successfulAgentSchedules=SafeJson({}),
|
||||||
)
|
)
|
||||||
return await PrismaChatSession.prisma().create(
|
return await PrismaChatSession.prisma().create(data=data)
|
||||||
data=data,
|
|
||||||
include={"Messages": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def update_chat_session(
|
async def update_chat_session(
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
@@ -104,6 +104,26 @@ class ChatSession(BaseModel):
|
|||||||
successful_agent_runs: dict[str, int] = {}
|
successful_agent_runs: dict[str, int] = {}
|
||||||
successful_agent_schedules: dict[str, int] = {}
|
successful_agent_schedules: dict[str, int] = {}
|
||||||
|
|
||||||
|
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
|
||||||
|
"""Attach a tool_call to the current turn's assistant message.
|
||||||
|
|
||||||
|
Searches backwards for the most recent assistant message (stopping at
|
||||||
|
any user message boundary). If found, appends the tool_call to it.
|
||||||
|
Otherwise creates a new assistant message with the tool_call.
|
||||||
|
"""
|
||||||
|
for msg in reversed(self.messages):
|
||||||
|
if msg.role == "user":
|
||||||
|
break
|
||||||
|
if msg.role == "assistant":
|
||||||
|
if not msg.tool_calls:
|
||||||
|
msg.tool_calls = []
|
||||||
|
msg.tool_calls.append(tool_call)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.messages.append(
|
||||||
|
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def new(user_id: str) -> "ChatSession":
|
def new(user_id: str) -> "ChatSession":
|
||||||
return ChatSession(
|
return ChatSession(
|
||||||
@@ -172,6 +192,47 @@ class ChatSession(BaseModel):
|
|||||||
successful_agent_schedules=successful_agent_schedules,
|
successful_agent_schedules=successful_agent_schedules,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_consecutive_assistant_messages(
|
||||||
|
messages: list[ChatCompletionMessageParam],
|
||||||
|
) -> list[ChatCompletionMessageParam]:
|
||||||
|
"""Merge consecutive assistant messages into single messages.
|
||||||
|
|
||||||
|
Long-running tool flows can create split assistant messages: one with
|
||||||
|
text content and another with tool_calls. Anthropic's API requires
|
||||||
|
tool_result blocks to reference a tool_use in the immediately preceding
|
||||||
|
assistant message, so these splits cause 400 errors via OpenRouter.
|
||||||
|
"""
|
||||||
|
if len(messages) < 2:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
result: list[ChatCompletionMessageParam] = [messages[0]]
|
||||||
|
for msg in messages[1:]:
|
||||||
|
prev = result[-1]
|
||||||
|
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
|
||||||
|
result.append(msg)
|
||||||
|
continue
|
||||||
|
|
||||||
|
prev = cast(ChatCompletionAssistantMessageParam, prev)
|
||||||
|
curr = cast(ChatCompletionAssistantMessageParam, msg)
|
||||||
|
|
||||||
|
curr_content = curr.get("content") or ""
|
||||||
|
if curr_content:
|
||||||
|
prev_content = prev.get("content") or ""
|
||||||
|
prev["content"] = (
|
||||||
|
f"{prev_content}\n{curr_content}" if prev_content else curr_content
|
||||||
|
)
|
||||||
|
|
||||||
|
curr_tool_calls = curr.get("tool_calls")
|
||||||
|
if curr_tool_calls:
|
||||||
|
prev_tool_calls = prev.get("tool_calls")
|
||||||
|
prev["tool_calls"] = (
|
||||||
|
list(prev_tool_calls) + list(curr_tool_calls)
|
||||||
|
if prev_tool_calls
|
||||||
|
else list(curr_tool_calls)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||||
messages = []
|
messages = []
|
||||||
for message in self.messages:
|
for message in self.messages:
|
||||||
@@ -258,7 +319,7 @@ class ChatSession(BaseModel):
|
|||||||
name=message.name or "",
|
name=message.name or "",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return messages
|
return self._merge_consecutive_assistant_messages(messages)
|
||||||
|
|
||||||
|
|
||||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||||
@@ -273,9 +334,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
|||||||
try:
|
try:
|
||||||
session = ChatSession.model_validate_json(raw_session)
|
session = ChatSession.model_validate_json(raw_session)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loading session {session_id} from cache: "
|
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
|
||||||
f"message_count={len(session.messages)}, "
|
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
|
||||||
f"roles={[m.role for m in session.messages]}"
|
|
||||||
)
|
)
|
||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -317,11 +377,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
messages = prisma_session.Messages
|
messages = prisma_session.Messages
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"Loading session {session_id} from DB: "
|
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
|
||||||
f"has_messages={messages is not None}, "
|
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
|
||||||
f"message_count={len(messages) if messages else 0}, "
|
|
||||||
f"roles={[m.role for m in messages] if messages else []}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatSession.from_db(prisma_session, messages)
|
return ChatSession.from_db(prisma_session, messages)
|
||||||
@@ -372,10 +430,9 @@ async def _save_session_to_db(
|
|||||||
"function_call": msg.function_call,
|
"function_call": msg.function_call,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
|
||||||
f"roles={[m['role'] for m in messages_data]}, "
|
f"roles={[m['role'] for m in messages_data]}"
|
||||||
f"start_sequence={existing_message_count}"
|
|
||||||
)
|
)
|
||||||
await chat_db.add_chat_messages_batch(
|
await chat_db.add_chat_messages_batch(
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
@@ -415,7 +472,7 @@ async def get_chat_session(
|
|||||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||||
|
|
||||||
# Fall back to database
|
# Fall back to database
|
||||||
logger.info(f"Session {session_id} not in cache, checking database")
|
logger.debug(f"Session {session_id} not in cache, checking database")
|
||||||
session = await _get_session_from_db(session_id)
|
session = await _get_session_from_db(session_id)
|
||||||
|
|
||||||
if session is None:
|
if session is None:
|
||||||
@@ -432,7 +489,6 @@ async def get_chat_session(
|
|||||||
# Cache the session from DB
|
# Cache the session from DB
|
||||||
try:
|
try:
|
||||||
await _cache_session(session)
|
await _cache_session(session)
|
||||||
logger.info(f"Cached session {session_id} from database")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||||
|
|
||||||
@@ -497,6 +553,40 @@ async def upsert_chat_session(
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
||||||
|
"""Atomically append a message to a session and persist it.
|
||||||
|
|
||||||
|
Acquires the session lock, re-fetches the latest session state,
|
||||||
|
appends the message, and saves — preventing message loss when
|
||||||
|
concurrent requests modify the same session.
|
||||||
|
"""
|
||||||
|
lock = await _get_session_lock(session_id)
|
||||||
|
|
||||||
|
async with lock:
|
||||||
|
session = await get_chat_session(session_id)
|
||||||
|
if session is None:
|
||||||
|
raise ValueError(f"Session {session_id} not found")
|
||||||
|
|
||||||
|
session.messages.append(message)
|
||||||
|
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||||
|
session_id
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _save_session_to_db(session, existing_message_count)
|
||||||
|
except Exception as e:
|
||||||
|
raise DatabaseError(
|
||||||
|
f"Failed to persist message to session {session_id}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _cache_session(session)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
async def create_chat_session(user_id: str) -> ChatSession:
|
async def create_chat_session(user_id: str) -> ChatSession:
|
||||||
"""Create a new chat session and persist it.
|
"""Create a new chat session and persist it.
|
||||||
|
|
||||||
@@ -603,13 +693,19 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
|||||||
logger.warning(f"Session {session_id} not found for title update")
|
logger.warning(f"Session {session_id} not found for title update")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Invalidate cache so next fetch gets updated title
|
# Update title in cache if it exists (instead of invalidating).
|
||||||
|
# This prevents race conditions where cache invalidation causes
|
||||||
|
# the frontend to see stale DB data while streaming is still in progress.
|
||||||
try:
|
try:
|
||||||
redis_key = _get_session_cache_key(session_id)
|
cached = await _get_session_from_cache(session_id)
|
||||||
async_redis = await get_redis_async()
|
if cached:
|
||||||
await async_redis.delete(redis_key)
|
cached.title = title
|
||||||
|
await _cache_session(cached)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
# Not critical - title will be correct on next full cache refresh
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to update title in cache for session {session_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,4 +1,16 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletionAssistantMessageParam,
|
||||||
|
ChatCompletionMessageParam,
|
||||||
|
ChatCompletionToolMessageParam,
|
||||||
|
ChatCompletionUserMessageParam,
|
||||||
|
)
|
||||||
|
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||||
|
ChatCompletionMessageToolCallParam,
|
||||||
|
Function,
|
||||||
|
)
|
||||||
|
|
||||||
from .model import (
|
from .model import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
@@ -117,3 +129,205 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
|||||||
loaded.tool_calls is not None
|
loaded.tool_calls is not None
|
||||||
), f"Tool calls missing for {orig.role} message"
|
), f"Tool calls missing for {orig.role} message"
|
||||||
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# _merge_consecutive_assistant_messages #
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
_tc = ChatCompletionMessageToolCallParam(
|
||||||
|
id="tc1", type="function", function=Function(name="do_stuff", arguments="{}")
|
||||||
|
)
|
||||||
|
_tc2 = ChatCompletionMessageToolCallParam(
|
||||||
|
id="tc2", type="function", function=Function(name="other", arguments="{}")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_noop_when_no_consecutive_assistants():
|
||||||
|
"""Messages without consecutive assistants are returned unchanged."""
|
||||||
|
msgs = [
|
||||||
|
ChatCompletionUserMessageParam(role="user", content="hi"),
|
||||||
|
ChatCompletionAssistantMessageParam(role="assistant", content="hello"),
|
||||||
|
ChatCompletionUserMessageParam(role="user", content="bye"),
|
||||||
|
]
|
||||||
|
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
|
||||||
|
assert len(merged) == 3
|
||||||
|
assert [m["role"] for m in merged] == ["user", "assistant", "user"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_splits_text_and_tool_calls():
|
||||||
|
"""The exact bug scenario: text-only assistant followed by tool_calls-only assistant."""
|
||||||
|
msgs = [
|
||||||
|
ChatCompletionUserMessageParam(role="user", content="build agent"),
|
||||||
|
ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant", content="Let me build that"
|
||||||
|
),
|
||||||
|
ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant", content="", tool_calls=[_tc]
|
||||||
|
),
|
||||||
|
ChatCompletionToolMessageParam(role="tool", content="ok", tool_call_id="tc1"),
|
||||||
|
]
|
||||||
|
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
|
||||||
|
|
||||||
|
assert len(merged) == 3
|
||||||
|
assert merged[0]["role"] == "user"
|
||||||
|
assert merged[2]["role"] == "tool"
|
||||||
|
a = cast(ChatCompletionAssistantMessageParam, merged[1])
|
||||||
|
assert a["role"] == "assistant"
|
||||||
|
assert a.get("content") == "Let me build that"
|
||||||
|
assert a.get("tool_calls") == [_tc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_combines_tool_calls_from_both():
|
||||||
|
"""Both consecutive assistants have tool_calls — they get merged."""
|
||||||
|
msgs: list[ChatCompletionAssistantMessageParam] = [
|
||||||
|
ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant", content="text", tool_calls=[_tc]
|
||||||
|
),
|
||||||
|
ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant", content="", tool_calls=[_tc2]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert len(merged) == 1
|
||||||
|
a = cast(ChatCompletionAssistantMessageParam, merged[0])
|
||||||
|
assert a.get("tool_calls") == [_tc, _tc2]
|
||||||
|
assert a.get("content") == "text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_three_consecutive_assistants():
|
||||||
|
"""Three consecutive assistants collapse into one."""
|
||||||
|
msgs: list[ChatCompletionAssistantMessageParam] = [
|
||||||
|
ChatCompletionAssistantMessageParam(role="assistant", content="a"),
|
||||||
|
ChatCompletionAssistantMessageParam(role="assistant", content="b"),
|
||||||
|
ChatCompletionAssistantMessageParam(
|
||||||
|
role="assistant", content="", tool_calls=[_tc]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert len(merged) == 1
|
||||||
|
a = cast(ChatCompletionAssistantMessageParam, merged[0])
|
||||||
|
assert a.get("content") == "a\nb"
|
||||||
|
assert a.get("tool_calls") == [_tc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_empty_and_single_message():
|
||||||
|
"""Edge cases: empty list and single message."""
|
||||||
|
assert ChatSession._merge_consecutive_assistant_messages([]) == []
|
||||||
|
|
||||||
|
single: list[ChatCompletionMessageParam] = [
|
||||||
|
ChatCompletionUserMessageParam(role="user", content="hi")
|
||||||
|
]
|
||||||
|
assert ChatSession._merge_consecutive_assistant_messages(single) == single
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# add_tool_call_to_current_turn #
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
_raw_tc = {
|
||||||
|
"id": "tc1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "f", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
_raw_tc2 = {
|
||||||
|
"id": "tc2",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "g", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_tool_call_appends_to_existing_assistant():
|
||||||
|
"""When the last assistant is from the current turn, tool_call is added to it."""
|
||||||
|
session = ChatSession.new(user_id="u")
|
||||||
|
session.messages = [
|
||||||
|
ChatMessage(role="user", content="hi"),
|
||||||
|
ChatMessage(role="assistant", content="working on it"),
|
||||||
|
]
|
||||||
|
session.add_tool_call_to_current_turn(_raw_tc)
|
||||||
|
|
||||||
|
assert len(session.messages) == 2 # no new message created
|
||||||
|
assert session.messages[1].tool_calls == [_raw_tc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_tool_call_creates_assistant_when_none_exists():
|
||||||
|
"""When there's no current-turn assistant, a new one is created."""
|
||||||
|
session = ChatSession.new(user_id="u")
|
||||||
|
session.messages = [
|
||||||
|
ChatMessage(role="user", content="hi"),
|
||||||
|
]
|
||||||
|
session.add_tool_call_to_current_turn(_raw_tc)
|
||||||
|
|
||||||
|
assert len(session.messages) == 2
|
||||||
|
assert session.messages[1].role == "assistant"
|
||||||
|
assert session.messages[1].tool_calls == [_raw_tc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_tool_call_does_not_cross_user_boundary():
|
||||||
|
"""A user message acts as a boundary — previous assistant is not modified."""
|
||||||
|
session = ChatSession.new(user_id="u")
|
||||||
|
session.messages = [
|
||||||
|
ChatMessage(role="assistant", content="old turn"),
|
||||||
|
ChatMessage(role="user", content="new message"),
|
||||||
|
]
|
||||||
|
session.add_tool_call_to_current_turn(_raw_tc)
|
||||||
|
|
||||||
|
assert len(session.messages) == 3 # new assistant was created
|
||||||
|
assert session.messages[0].tool_calls is None # old assistant untouched
|
||||||
|
assert session.messages[2].role == "assistant"
|
||||||
|
assert session.messages[2].tool_calls == [_raw_tc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_tool_call_multiple_times():
|
||||||
|
"""Multiple long-running tool calls accumulate on the same assistant."""
|
||||||
|
session = ChatSession.new(user_id="u")
|
||||||
|
session.messages = [
|
||||||
|
ChatMessage(role="user", content="hi"),
|
||||||
|
ChatMessage(role="assistant", content="doing stuff"),
|
||||||
|
]
|
||||||
|
session.add_tool_call_to_current_turn(_raw_tc)
|
||||||
|
# Simulate a pending tool result in between (like _yield_tool_call does)
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(role="tool", content="pending", tool_call_id="tc1")
|
||||||
|
)
|
||||||
|
session.add_tool_call_to_current_turn(_raw_tc2)
|
||||||
|
|
||||||
|
assert len(session.messages) == 3 # user, assistant, tool — no extra assistant
|
||||||
|
assert session.messages[1].tool_calls == [_raw_tc, _raw_tc2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_openai_messages_merges_split_assistants():
|
||||||
|
"""End-to-end: session with split assistants produces valid OpenAI messages."""
|
||||||
|
session = ChatSession.new(user_id="u")
|
||||||
|
session.messages = [
|
||||||
|
ChatMessage(role="user", content="build agent"),
|
||||||
|
ChatMessage(role="assistant", content="Let me build that"),
|
||||||
|
ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": "tc1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "create_agent", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
ChatMessage(role="tool", content="done", tool_call_id="tc1"),
|
||||||
|
ChatMessage(role="assistant", content="Saved!"),
|
||||||
|
ChatMessage(role="user", content="show me an example run"),
|
||||||
|
]
|
||||||
|
openai_msgs = session.to_openai_messages()
|
||||||
|
|
||||||
|
# The two consecutive assistants at index 1,2 should be merged
|
||||||
|
roles = [m["role"] for m in openai_msgs]
|
||||||
|
assert roles == ["user", "assistant", "tool", "assistant", "user"]
|
||||||
|
|
||||||
|
# The merged assistant should have both content and tool_calls
|
||||||
|
merged = cast(ChatCompletionAssistantMessageParam, openai_msgs[1])
|
||||||
|
assert merged.get("content") == "Let me build that"
|
||||||
|
tc_list = merged.get("tool_calls")
|
||||||
|
assert tc_list is not None and len(list(tc_list)) == 1
|
||||||
|
assert list(tc_list)[0]["id"] == "tc1"
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.util.json import dumps as json_dumps
|
||||||
|
|
||||||
|
|
||||||
class ResponseType(str, Enum):
|
class ResponseType(str, Enum):
|
||||||
"""Types of streaming responses following AI SDK protocol."""
|
"""Types of streaming responses following AI SDK protocol."""
|
||||||
@@ -18,6 +20,10 @@ class ResponseType(str, Enum):
|
|||||||
START = "start"
|
START = "start"
|
||||||
FINISH = "finish"
|
FINISH = "finish"
|
||||||
|
|
||||||
|
# Step lifecycle (one LLM API call within a message)
|
||||||
|
START_STEP = "start-step"
|
||||||
|
FINISH_STEP = "finish-step"
|
||||||
|
|
||||||
# Text streaming
|
# Text streaming
|
||||||
TEXT_START = "text-start"
|
TEXT_START = "text-start"
|
||||||
TEXT_DELTA = "text-delta"
|
TEXT_DELTA = "text-delta"
|
||||||
@@ -57,6 +63,16 @@ class StreamStart(StreamBaseResponse):
|
|||||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE format, excluding non-protocol fields like taskId."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"messageId": self.messageId,
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
class StreamFinish(StreamBaseResponse):
|
class StreamFinish(StreamBaseResponse):
|
||||||
"""End of message/stream."""
|
"""End of message/stream."""
|
||||||
@@ -64,6 +80,26 @@ class StreamFinish(StreamBaseResponse):
|
|||||||
type: ResponseType = ResponseType.FINISH
|
type: ResponseType = ResponseType.FINISH
|
||||||
|
|
||||||
|
|
||||||
|
class StreamStartStep(StreamBaseResponse):
|
||||||
|
"""Start of a step (one LLM API call within a message).
|
||||||
|
|
||||||
|
The AI SDK uses this to add a step-start boundary to message.parts,
|
||||||
|
enabling visual separation between multiple LLM calls in a single message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.START_STEP
|
||||||
|
|
||||||
|
|
||||||
|
class StreamFinishStep(StreamBaseResponse):
|
||||||
|
"""End of a step (one LLM API call within a message).
|
||||||
|
|
||||||
|
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
|
||||||
|
so the next LLM call in a tool-call continuation starts with clean state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.FINISH_STEP
|
||||||
|
|
||||||
|
|
||||||
# ========== Text Streaming ==========
|
# ========== Text Streaming ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -117,7 +153,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
# Keep these for internal backend use
|
||||||
toolName: str | None = Field(
|
toolName: str | None = Field(
|
||||||
default=None, description="Name of the tool that was executed"
|
default=None, description="Name of the tool that was executed"
|
||||||
)
|
)
|
||||||
@@ -125,6 +161,17 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
default=True, description="Whether the tool execution succeeded"
|
default=True, description="Whether the tool execution succeeded"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE format, excluding non-spec fields."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"toolCallId": self.toolCallId,
|
||||||
|
"output": self.output,
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
# ========== Other ==========
|
# ========== Other ==========
|
||||||
|
|
||||||
@@ -148,6 +195,18 @@ class StreamError(StreamBaseResponse):
|
|||||||
default=None, description="Additional error details"
|
default=None, description="Additional error details"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE format, only emitting fields required by AI SDK protocol.
|
||||||
|
|
||||||
|
The AI SDK uses z.strictObject({type, errorText}) which rejects
|
||||||
|
any extra fields like `code` or `details`.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"errorText": self.errorText,
|
||||||
|
}
|
||||||
|
return f"data: {json_dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
class StreamHeartbeat(StreamBaseResponse):
|
class StreamHeartbeat(StreamBaseResponse):
|
||||||
"""Heartbeat to keep SSE connection alive during long-running operations.
|
"""Heartbeat to keep SSE connection alive during long-running operations.
|
||||||
|
|||||||
@@ -1,23 +1,57 @@
|
|||||||
"""Chat API routes for chat session management and streaming via SSE."""
|
"""Chat API routes for chat session management and streaming via SSE."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from autogpt_libs import auth
|
from autogpt_libs import auth
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||||
|
|
||||||
from . import service as chat_service
|
from . import service as chat_service
|
||||||
from . import stream_registry
|
from . import stream_registry
|
||||||
from .completion_handler import process_operation_failure, process_operation_success
|
from .completion_handler import process_operation_failure, process_operation_success
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
from .model import (
|
||||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
ChatMessage,
|
||||||
|
ChatSession,
|
||||||
|
append_and_save_message,
|
||||||
|
create_chat_session,
|
||||||
|
get_chat_session,
|
||||||
|
get_user_sessions,
|
||||||
|
)
|
||||||
|
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
|
||||||
|
from .sdk import service as sdk_service
|
||||||
|
from .tools.models import (
|
||||||
|
AgentDetailsResponse,
|
||||||
|
AgentOutputResponse,
|
||||||
|
AgentPreviewResponse,
|
||||||
|
AgentSavedResponse,
|
||||||
|
AgentsFoundResponse,
|
||||||
|
BlockDetailsResponse,
|
||||||
|
BlockListResponse,
|
||||||
|
BlockOutputResponse,
|
||||||
|
ClarificationNeededResponse,
|
||||||
|
DocPageResponse,
|
||||||
|
DocSearchResultsResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
ExecutionStartedResponse,
|
||||||
|
InputValidationErrorResponse,
|
||||||
|
NeedLoginResponse,
|
||||||
|
NoResultsResponse,
|
||||||
|
OperationInProgressResponse,
|
||||||
|
OperationPendingResponse,
|
||||||
|
OperationStartedResponse,
|
||||||
|
SetupRequirementsResponse,
|
||||||
|
UnderstandingUpdatedResponse,
|
||||||
|
)
|
||||||
|
from .tracking import track_user_message
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
@@ -209,6 +243,10 @@ async def get_session(
|
|||||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||||
session_id, user_id
|
session_id, user_id
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
|
||||||
|
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||||
|
)
|
||||||
if active_task:
|
if active_task:
|
||||||
# Filter out the in-progress assistant message from the session response.
|
# Filter out the in-progress assistant message from the session response.
|
||||||
# The client will receive the complete assistant response through the SSE
|
# The client will receive the complete assistant response through the SSE
|
||||||
@@ -266,12 +304,54 @@ async def stream_chat_post(
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
stream_start_time = time.perf_counter()
|
||||||
|
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||||
|
f"user={user_id}, message_len={len(request.message)}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Atomically append user message to session BEFORE creating task to avoid
|
||||||
|
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||||
|
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||||
|
# message loss from concurrent requests.
|
||||||
|
if request.message:
|
||||||
|
message = ChatMessage(
|
||||||
|
role="user" if request.is_user_message else "assistant",
|
||||||
|
content=request.message,
|
||||||
|
)
|
||||||
|
if request.is_user_message:
|
||||||
|
track_user_message(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
message_length=len(request.message),
|
||||||
|
)
|
||||||
|
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||||
|
session = await append_and_save_message(session_id, message)
|
||||||
|
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||||
|
|
||||||
# Create a task in the stream registry for reconnection support
|
# Create a task in the stream registry for reconnection support
|
||||||
task_id = str(uuid_module.uuid4())
|
task_id = str(uuid_module.uuid4())
|
||||||
operation_id = str(uuid_module.uuid4())
|
operation_id = str(uuid_module.uuid4())
|
||||||
|
log_meta["task_id"] = task_id
|
||||||
|
|
||||||
|
task_create_start = time.perf_counter()
|
||||||
await stream_registry.create_task(
|
await stream_registry.create_task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -280,40 +360,151 @@ async def stream_chat_post(
|
|||||||
tool_name="chat",
|
tool_name="chat",
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Background task that runs the AI generation independently of SSE connection
|
# Background task that runs the AI generation independently of SSE connection
|
||||||
async def run_ai_generation():
|
async def run_ai_generation():
|
||||||
|
import time as time_module
|
||||||
|
|
||||||
|
gen_start_time = time_module.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
first_chunk_time, ttfc = None, None
|
||||||
|
chunk_count = 0
|
||||||
try:
|
try:
|
||||||
# Emit a start event with task_id for reconnection
|
# Emit a start event with task_id for reconnection
|
||||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
|
||||||
|
* 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
# Choose service based on LaunchDarkly flag (falls back to config default)
|
||||||
|
use_sdk = await is_feature_enabled(
|
||||||
|
Flag.COPILOT_SDK,
|
||||||
|
user_id or "anonymous",
|
||||||
|
default=config.use_claude_agent_sdk,
|
||||||
|
)
|
||||||
|
stream_fn = (
|
||||||
|
sdk_service.stream_chat_completion_sdk
|
||||||
|
if use_sdk
|
||||||
|
else chat_service.stream_chat_completion
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
# Pass message=None since we already added it to the session above
|
||||||
|
async for chunk in stream_fn(
|
||||||
session_id,
|
session_id,
|
||||||
request.message,
|
None, # Message already in session
|
||||||
is_user_message=request.is_user_message,
|
is_user_message=request.is_user_message,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass session with message already added
|
||||||
context=request.context,
|
context=request.context,
|
||||||
):
|
):
|
||||||
|
# Skip duplicate StreamStart — we already published one above
|
||||||
|
if isinstance(chunk, StreamStart):
|
||||||
|
continue
|
||||||
|
chunk_count += 1
|
||||||
|
if first_chunk_time is None:
|
||||||
|
first_chunk_time = time_module.perf_counter()
|
||||||
|
ttfc = first_chunk_time - gen_start_time
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunk_type": type(chunk).__name__,
|
||||||
|
"time_to_first_chunk_ms": ttfc * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
# Write to Redis (subscribers will receive via XREAD)
|
# Write to Redis (subscribers will receive via XREAD)
|
||||||
await stream_registry.publish_chunk(task_id, chunk)
|
await stream_registry.publish_chunk(task_id, chunk)
|
||||||
|
|
||||||
# Mark task as completed
|
gen_end_time = time_module.perf_counter()
|
||||||
|
total_time = (gen_end_time - gen_start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
|
||||||
|
f"task={task_id}, session={session_id}, "
|
||||||
|
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"time_to_first_chunk_ms": (
|
||||||
|
ttfc * 1000 if ttfc is not None else None
|
||||||
|
),
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
await stream_registry.mark_task_completed(task_id, "completed")
|
await stream_registry.mark_task_completed(task_id, "completed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
elapsed = time_module.perf_counter() - gen_start_time
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error in background AI generation for session {session_id}: {e}"
|
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed * 1000,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
# Publish a StreamError so the frontend can display an error message
|
||||||
|
try:
|
||||||
|
await stream_registry.publish_chunk(
|
||||||
|
task_id,
|
||||||
|
StreamError(
|
||||||
|
errorText="An error occurred. Please try again.",
|
||||||
|
code="stream_error",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Best-effort; mark_task_completed will publish StreamFinish
|
||||||
await stream_registry.mark_task_completed(task_id, "failed")
|
await stream_registry.mark_task_completed(task_id, "failed")
|
||||||
|
|
||||||
# Start the AI generation in a background task
|
# Start the AI generation in a background task
|
||||||
bg_task = asyncio.create_task(run_ai_generation())
|
bg_task = asyncio.create_task(run_ai_generation())
|
||||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||||
|
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# SSE endpoint that subscribes to the task's stream
|
# SSE endpoint that subscribes to the task's stream
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
|
import time as time_module
|
||||||
|
|
||||||
|
event_gen_start = time_module.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
||||||
|
f"user={user_id}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
subscriber_queue = None
|
subscriber_queue = None
|
||||||
|
first_chunk_yielded = False
|
||||||
|
chunks_yielded = 0
|
||||||
try:
|
try:
|
||||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
@@ -328,24 +519,78 @@ async def stream_chat_post(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Read from the subscriber queue and yield to SSE
|
# Read from the subscriber queue and yield to SSE
|
||||||
|
logger.info(
|
||||||
|
"[TIMING] Starting to read from subscriber_queue",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||||
|
chunks_yielded += 1
|
||||||
|
|
||||||
|
if not first_chunk_yielded:
|
||||||
|
first_chunk_yielded = True
|
||||||
|
elapsed = time_module.perf_counter() - event_gen_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
|
||||||
|
f"type={type(chunk).__name__}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunk_type": type(chunk).__name__,
|
||||||
|
"elapsed_ms": elapsed * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
yield chunk.to_sse()
|
yield chunk.to_sse()
|
||||||
|
|
||||||
# Check for finish signal
|
# Check for finish signal
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
|
total_time = time_module.perf_counter() - event_gen_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
|
||||||
|
f"n_chunks={chunks_yielded}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunks_yielded": chunks_yielded,
|
||||||
|
"total_time_ms": total_time * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
break
|
break
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# Send heartbeat to keep connection alive
|
|
||||||
yield StreamHeartbeat().to_sse()
|
yield StreamHeartbeat().to_sse()
|
||||||
|
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunks_yielded": chunks_yielded,
|
||||||
|
"reason": "client_disconnect",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
pass # Client disconnected - background task continues
|
pass # Client disconnected - background task continues
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||||
|
logger.error(
|
||||||
|
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Surface error to frontend so it doesn't appear stuck
|
||||||
|
yield StreamError(
|
||||||
|
errorText="An error occurred. Please try again.",
|
||||||
|
code="stream_error",
|
||||||
|
).to_sse()
|
||||||
|
yield StreamFinish().to_sse()
|
||||||
finally:
|
finally:
|
||||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
# Unsubscribe when client disconnects or stream ends
|
||||||
if subscriber_queue is not None:
|
if subscriber_queue is not None:
|
||||||
try:
|
try:
|
||||||
await stream_registry.unsubscribe_from_task(
|
await stream_registry.unsubscribe_from_task(
|
||||||
@@ -357,6 +602,18 @@ async def stream_chat_post(
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||||
|
total_time = time_module.perf_counter() - event_gen_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||||
|
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time * 1000,
|
||||||
|
"chunks_yielded": chunks_yielded,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
@@ -374,63 +631,90 @@ async def stream_chat_post(
|
|||||||
@router.get(
|
@router.get(
|
||||||
"/sessions/{session_id}/stream",
|
"/sessions/{session_id}/stream",
|
||||||
)
|
)
|
||||||
async def stream_chat_get(
|
async def resume_session_stream(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
user_id: str | None = Depends(auth.get_user_id),
|
||||||
is_user_message: bool = Query(default=True),
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Stream chat responses for a session (GET - legacy endpoint).
|
Resume an active stream for a session.
|
||||||
|
|
||||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
Called by the AI SDK's ``useChat(resume: true)`` on page load.
|
||||||
- Text fragments as they are generated
|
Checks for an active (in-progress) task on the session and either replays
|
||||||
- Tool call UI elements (if invoked)
|
the full SSE stream or returns 204 No Content if nothing is running.
|
||||||
- Tool execution results
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The chat session identifier to associate with the streamed messages.
|
session_id: The chat session identifier.
|
||||||
message: The user's new message to process.
|
|
||||||
user_id: Optional authenticated user ID.
|
user_id: Optional authenticated user ID.
|
||||||
is_user_message: Whether the message is a user message.
|
|
||||||
Returns:
|
|
||||||
StreamingResponse: SSE-formatted response chunks.
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StreamingResponse (SSE) when an active stream exists,
|
||||||
|
or 204 No Content when there is nothing to resume.
|
||||||
"""
|
"""
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
import asyncio
|
||||||
|
|
||||||
|
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
||||||
|
session_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not active_task:
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
|
task_id=active_task.task_id,
|
||||||
|
user_id=user_id,
|
||||||
|
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
||||||
|
)
|
||||||
|
|
||||||
|
if subscriber_queue is None:
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
first_chunk_type: str | None = None
|
first_chunk_type: str | None = None
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
try:
|
||||||
session_id,
|
while True:
|
||||||
message,
|
try:
|
||||||
is_user_message=is_user_message,
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||||
user_id=user_id,
|
if chunk_count < 3:
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
logger.info(
|
||||||
):
|
"Resume stream chunk",
|
||||||
if chunk_count < 3:
|
extra={
|
||||||
logger.info(
|
"session_id": session_id,
|
||||||
"Chat stream chunk",
|
"chunk_type": str(chunk.type),
|
||||||
extra={
|
},
|
||||||
"session_id": session_id,
|
)
|
||||||
"chunk_type": str(chunk.type),
|
if not first_chunk_type:
|
||||||
},
|
first_chunk_type = str(chunk.type)
|
||||||
|
chunk_count += 1
|
||||||
|
yield chunk.to_sse()
|
||||||
|
|
||||||
|
if isinstance(chunk, StreamFinish):
|
||||||
|
break
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
yield StreamHeartbeat().to_sse()
|
||||||
|
except GeneratorExit:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await stream_registry.unsubscribe_from_task(
|
||||||
|
active_task.task_id, subscriber_queue
|
||||||
)
|
)
|
||||||
if not first_chunk_type:
|
except Exception as unsub_err:
|
||||||
first_chunk_type = str(chunk.type)
|
logger.error(
|
||||||
chunk_count += 1
|
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
||||||
yield chunk.to_sse()
|
exc_info=True,
|
||||||
logger.info(
|
)
|
||||||
"Chat stream completed",
|
logger.info(
|
||||||
extra={
|
"Resume stream completed",
|
||||||
"session_id": session_id,
|
extra={
|
||||||
"chunk_count": chunk_count,
|
"session_id": session_id,
|
||||||
"first_chunk_type": first_chunk_type,
|
"n_chunks": chunk_count,
|
||||||
},
|
"first_chunk_type": first_chunk_type,
|
||||||
)
|
},
|
||||||
# AI SDK protocol termination
|
)
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
event_generator(),
|
||||||
@@ -438,8 +722,8 @@ async def stream_chat_get(
|
|||||||
headers={
|
headers={
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
"Connection": "keep-alive",
|
"Connection": "keep-alive",
|
||||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
"X-Accel-Buffering": "no",
|
||||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
"x-vercel-ai-ui-message-stream": "v1",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -550,8 +834,6 @@ async def stream_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
import asyncio
|
|
||||||
|
|
||||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
@@ -751,3 +1033,43 @@ async def health_check() -> dict:
|
|||||||
"service": "chat",
|
"service": "chat",
|
||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
|
||||||
|
|
||||||
|
ToolResponseUnion = (
|
||||||
|
AgentsFoundResponse
|
||||||
|
| NoResultsResponse
|
||||||
|
| AgentDetailsResponse
|
||||||
|
| SetupRequirementsResponse
|
||||||
|
| ExecutionStartedResponse
|
||||||
|
| NeedLoginResponse
|
||||||
|
| ErrorResponse
|
||||||
|
| InputValidationErrorResponse
|
||||||
|
| AgentOutputResponse
|
||||||
|
| UnderstandingUpdatedResponse
|
||||||
|
| AgentPreviewResponse
|
||||||
|
| AgentSavedResponse
|
||||||
|
| ClarificationNeededResponse
|
||||||
|
| BlockListResponse
|
||||||
|
| BlockDetailsResponse
|
||||||
|
| BlockOutputResponse
|
||||||
|
| DocSearchResultsResponse
|
||||||
|
| DocPageResponse
|
||||||
|
| OperationStartedResponse
|
||||||
|
| OperationPendingResponse
|
||||||
|
| OperationInProgressResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/schema/tool-responses",
|
||||||
|
response_model=ToolResponseUnion,
|
||||||
|
include_in_schema=True,
|
||||||
|
summary="[Dummy] Tool response type export for codegen",
|
||||||
|
description="This endpoint is not meant to be called. It exists solely to "
|
||||||
|
"expose tool response models in the OpenAPI schema for frontend codegen.",
|
||||||
|
)
|
||||||
|
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
|
||||||
|
"""Never called at runtime. Exists only so Orval generates TS types."""
|
||||||
|
raise HTTPException(status_code=501, detail="Schema-only endpoint")
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""Claude Agent SDK integration for CoPilot.
|
||||||
|
|
||||||
|
This module provides the integration layer between the Claude Agent SDK
|
||||||
|
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||||
|
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .service import stream_chat_completion_sdk
|
||||||
|
from .tool_adapter import create_copilot_mcp_server
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"stream_chat_completion_sdk",
|
||||||
|
"create_copilot_mcp_server",
|
||||||
|
]
|
||||||
@@ -0,0 +1,203 @@
|
|||||||
|
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||||
|
|
||||||
|
This module provides the adapter layer that converts streaming messages from
|
||||||
|
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
|
||||||
|
the frontend expects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from claude_agent_sdk import (
|
||||||
|
AssistantMessage,
|
||||||
|
Message,
|
||||||
|
ResultMessage,
|
||||||
|
SystemMessage,
|
||||||
|
TextBlock,
|
||||||
|
ToolResultBlock,
|
||||||
|
ToolUseBlock,
|
||||||
|
UserMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
from backend.api.features.chat.response_model import (
|
||||||
|
StreamBaseResponse,
|
||||||
|
StreamError,
|
||||||
|
StreamFinish,
|
||||||
|
StreamFinishStep,
|
||||||
|
StreamStart,
|
||||||
|
StreamStartStep,
|
||||||
|
StreamTextDelta,
|
||||||
|
StreamTextEnd,
|
||||||
|
StreamTextStart,
|
||||||
|
StreamToolInputAvailable,
|
||||||
|
StreamToolInputStart,
|
||||||
|
StreamToolOutputAvailable,
|
||||||
|
)
|
||||||
|
from backend.api.features.chat.sdk.tool_adapter import (
|
||||||
|
MCP_TOOL_PREFIX,
|
||||||
|
pop_pending_tool_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SDKResponseAdapter:
|
||||||
|
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||||
|
|
||||||
|
This class maintains state during a streaming session to properly track
|
||||||
|
text blocks, tool calls, and message lifecycle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message_id: str | None = None):
|
||||||
|
self.message_id = message_id or str(uuid.uuid4())
|
||||||
|
self.text_block_id = str(uuid.uuid4())
|
||||||
|
self.has_started_text = False
|
||||||
|
self.has_ended_text = False
|
||||||
|
self.current_tool_calls: dict[str, dict[str, str]] = {}
|
||||||
|
self.task_id: str | None = None
|
||||||
|
self.step_open = False
|
||||||
|
|
||||||
|
def set_task_id(self, task_id: str) -> None:
|
||||||
|
"""Set the task ID for reconnection support."""
|
||||||
|
self.task_id = task_id
|
||||||
|
|
||||||
|
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
|
||||||
|
"""Convert a single SDK message to Vercel AI SDK format."""
|
||||||
|
responses: list[StreamBaseResponse] = []
|
||||||
|
|
||||||
|
if isinstance(sdk_message, SystemMessage):
|
||||||
|
if sdk_message.subtype == "init":
|
||||||
|
responses.append(
|
||||||
|
StreamStart(messageId=self.message_id, taskId=self.task_id)
|
||||||
|
)
|
||||||
|
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
|
||||||
|
responses.append(StreamStartStep())
|
||||||
|
self.step_open = True
|
||||||
|
|
||||||
|
elif isinstance(sdk_message, AssistantMessage):
|
||||||
|
# After tool results, the SDK sends a new AssistantMessage for the
|
||||||
|
# next LLM turn. Open a new step if the previous one was closed.
|
||||||
|
if not self.step_open:
|
||||||
|
responses.append(StreamStartStep())
|
||||||
|
self.step_open = True
|
||||||
|
|
||||||
|
for block in sdk_message.content:
|
||||||
|
if isinstance(block, TextBlock):
|
||||||
|
if block.text:
|
||||||
|
self._ensure_text_started(responses)
|
||||||
|
responses.append(
|
||||||
|
StreamTextDelta(id=self.text_block_id, delta=block.text)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(block, ToolUseBlock):
|
||||||
|
self._end_text_if_open(responses)
|
||||||
|
|
||||||
|
# Strip MCP prefix so frontend sees "find_block"
|
||||||
|
# instead of "mcp__copilot__find_block".
|
||||||
|
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
|
||||||
|
|
||||||
|
responses.append(
|
||||||
|
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
|
||||||
|
)
|
||||||
|
responses.append(
|
||||||
|
StreamToolInputAvailable(
|
||||||
|
toolCallId=block.id,
|
||||||
|
toolName=tool_name,
|
||||||
|
input=block.input,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.current_tool_calls[block.id] = {"name": tool_name}
|
||||||
|
|
||||||
|
elif isinstance(sdk_message, UserMessage):
|
||||||
|
# UserMessage carries tool results back from tool execution.
|
||||||
|
content = sdk_message.content
|
||||||
|
blocks = content if isinstance(content, list) else []
|
||||||
|
for block in blocks:
|
||||||
|
if isinstance(block, ToolResultBlock) and block.tool_use_id:
|
||||||
|
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
|
||||||
|
tool_name = tool_info.get("name", "unknown")
|
||||||
|
|
||||||
|
# Prefer the stashed full output over the SDK's
|
||||||
|
# (potentially truncated) ToolResultBlock content.
|
||||||
|
# The SDK truncates large results, writing them to disk,
|
||||||
|
# which breaks frontend widget parsing.
|
||||||
|
output = pop_pending_tool_output(tool_name) or (
|
||||||
|
_extract_tool_output(block.content)
|
||||||
|
)
|
||||||
|
|
||||||
|
responses.append(
|
||||||
|
StreamToolOutputAvailable(
|
||||||
|
toolCallId=block.tool_use_id,
|
||||||
|
toolName=tool_name,
|
||||||
|
output=output,
|
||||||
|
success=not (block.is_error or False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Close the current step after tool results — the next
|
||||||
|
# AssistantMessage will open a new step for the continuation.
|
||||||
|
if self.step_open:
|
||||||
|
responses.append(StreamFinishStep())
|
||||||
|
self.step_open = False
|
||||||
|
|
||||||
|
elif isinstance(sdk_message, ResultMessage):
|
||||||
|
self._end_text_if_open(responses)
|
||||||
|
# Close the step before finishing.
|
||||||
|
if self.step_open:
|
||||||
|
responses.append(StreamFinishStep())
|
||||||
|
self.step_open = False
|
||||||
|
|
||||||
|
if sdk_message.subtype == "success":
|
||||||
|
responses.append(StreamFinish())
|
||||||
|
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||||
|
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
|
||||||
|
responses.append(
|
||||||
|
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||||
|
)
|
||||||
|
responses.append(StreamFinish())
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
|
||||||
|
)
|
||||||
|
responses.append(StreamFinish())
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||||
|
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
|
||||||
|
"""Start (or restart) a text block if needed."""
|
||||||
|
if not self.has_started_text or self.has_ended_text:
|
||||||
|
if self.has_ended_text:
|
||||||
|
self.text_block_id = str(uuid.uuid4())
|
||||||
|
self.has_ended_text = False
|
||||||
|
responses.append(StreamTextStart(id=self.text_block_id))
|
||||||
|
self.has_started_text = True
|
||||||
|
|
||||||
|
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
|
||||||
|
"""End the current text block if one is open."""
|
||||||
|
if self.has_started_text and not self.has_ended_text:
|
||||||
|
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||||
|
self.has_ended_text = True
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
|
||||||
|
"""Extract a string output from a ToolResultBlock's content field."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
||||||
|
if parts:
|
||||||
|
return "".join(parts)
|
||||||
|
try:
|
||||||
|
return json.dumps(content)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return str(content)
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
return json.dumps(content)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return str(content)
|
||||||
@@ -0,0 +1,366 @@
|
|||||||
|
"""Unit tests for the SDK response adapter."""
|
||||||
|
|
||||||
|
from claude_agent_sdk import (
|
||||||
|
AssistantMessage,
|
||||||
|
ResultMessage,
|
||||||
|
SystemMessage,
|
||||||
|
TextBlock,
|
||||||
|
ToolResultBlock,
|
||||||
|
ToolUseBlock,
|
||||||
|
UserMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
from backend.api.features.chat.response_model import (
|
||||||
|
StreamBaseResponse,
|
||||||
|
StreamError,
|
||||||
|
StreamFinish,
|
||||||
|
StreamFinishStep,
|
||||||
|
StreamStart,
|
||||||
|
StreamStartStep,
|
||||||
|
StreamTextDelta,
|
||||||
|
StreamTextEnd,
|
||||||
|
StreamTextStart,
|
||||||
|
StreamToolInputAvailable,
|
||||||
|
StreamToolInputStart,
|
||||||
|
StreamToolOutputAvailable,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .response_adapter import SDKResponseAdapter
|
||||||
|
from .tool_adapter import MCP_TOOL_PREFIX
|
||||||
|
|
||||||
|
|
||||||
|
def _adapter() -> SDKResponseAdapter:
|
||||||
|
a = SDKResponseAdapter(message_id="msg-1")
|
||||||
|
a.set_task_id("task-1")
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
# -- SystemMessage -----------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_init_emits_start_and_step():
|
||||||
|
adapter = _adapter()
|
||||||
|
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||||
|
assert len(results) == 2
|
||||||
|
assert isinstance(results[0], StreamStart)
|
||||||
|
assert results[0].messageId == "msg-1"
|
||||||
|
assert results[0].taskId == "task-1"
|
||||||
|
assert isinstance(results[1], StreamStartStep)
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_non_init_emits_nothing():
|
||||||
|
adapter = _adapter()
|
||||||
|
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
|
||||||
|
# -- AssistantMessage with TextBlock -----------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_block_emits_step_start_and_delta():
|
||||||
|
adapter = _adapter()
|
||||||
|
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
|
||||||
|
results = adapter.convert_message(msg)
|
||||||
|
assert len(results) == 3
|
||||||
|
assert isinstance(results[0], StreamStartStep)
|
||||||
|
assert isinstance(results[1], StreamTextStart)
|
||||||
|
assert isinstance(results[2], StreamTextDelta)
|
||||||
|
assert results[2].delta == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_text_block_emits_only_step():
|
||||||
|
adapter = _adapter()
|
||||||
|
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
|
||||||
|
results = adapter.convert_message(msg)
|
||||||
|
# Empty text skipped, but step still opens
|
||||||
|
assert len(results) == 1
|
||||||
|
assert isinstance(results[0], StreamStartStep)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_text_deltas_reuse_block_id():
|
||||||
|
adapter = _adapter()
|
||||||
|
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
|
||||||
|
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
|
||||||
|
r1 = adapter.convert_message(msg1)
|
||||||
|
r2 = adapter.convert_message(msg2)
|
||||||
|
# First gets step+start+delta, second only delta (block & step already started)
|
||||||
|
assert len(r1) == 3
|
||||||
|
assert isinstance(r1[0], StreamStartStep)
|
||||||
|
assert isinstance(r1[1], StreamTextStart)
|
||||||
|
assert len(r2) == 1
|
||||||
|
assert isinstance(r2[0], StreamTextDelta)
|
||||||
|
assert r1[1].id == r2[0].id # same block ID
|
||||||
|
|
||||||
|
|
||||||
|
# -- AssistantMessage with ToolUseBlock --------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_use_emits_input_start_and_available():
|
||||||
|
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
|
||||||
|
adapter = _adapter()
|
||||||
|
msg = AssistantMessage(
|
||||||
|
content=[
|
||||||
|
ToolUseBlock(
|
||||||
|
id="tool-1",
|
||||||
|
name=f"{MCP_TOOL_PREFIX}find_agent",
|
||||||
|
input={"q": "x"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
model="test",
|
||||||
|
)
|
||||||
|
results = adapter.convert_message(msg)
|
||||||
|
assert len(results) == 3
|
||||||
|
assert isinstance(results[0], StreamStartStep)
|
||||||
|
assert isinstance(results[1], StreamToolInputStart)
|
||||||
|
assert results[1].toolCallId == "tool-1"
|
||||||
|
assert results[1].toolName == "find_agent" # prefix stripped
|
||||||
|
assert isinstance(results[2], StreamToolInputAvailable)
|
||||||
|
assert results[2].toolName == "find_agent" # prefix stripped
|
||||||
|
assert results[2].input == {"q": "x"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_then_tool_ends_text_block():
|
||||||
|
adapter = _adapter()
|
||||||
|
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
|
||||||
|
tool_msg = AssistantMessage(
|
||||||
|
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||||
|
model="test",
|
||||||
|
)
|
||||||
|
adapter.convert_message(text_msg) # opens step + text
|
||||||
|
results = adapter.convert_message(tool_msg)
|
||||||
|
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
|
||||||
|
assert len(results) == 3
|
||||||
|
assert isinstance(results[0], StreamTextEnd)
|
||||||
|
assert isinstance(results[1], StreamToolInputStart)
|
||||||
|
|
||||||
|
|
||||||
|
# -- UserMessage with ToolResultBlock ----------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_result_emits_output_and_finish_step():
|
||||||
|
adapter = _adapter()
|
||||||
|
# First register the tool call (opens step) — SDK sends prefixed name
|
||||||
|
tool_msg = AssistantMessage(
|
||||||
|
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
|
||||||
|
model="test",
|
||||||
|
)
|
||||||
|
adapter.convert_message(tool_msg)
|
||||||
|
|
||||||
|
# Now send tool result
|
||||||
|
result_msg = UserMessage(
|
||||||
|
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
|
||||||
|
)
|
||||||
|
results = adapter.convert_message(result_msg)
|
||||||
|
assert len(results) == 2
|
||||||
|
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||||
|
assert results[0].toolCallId == "t1"
|
||||||
|
assert results[0].toolName == "find_agent" # prefix stripped
|
||||||
|
assert results[0].output == "found 3 agents"
|
||||||
|
assert results[0].success is True
|
||||||
|
assert isinstance(results[1], StreamFinishStep)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_result_error():
|
||||||
|
adapter = _adapter()
|
||||||
|
adapter.convert_message(
|
||||||
|
AssistantMessage(
|
||||||
|
content=[
|
||||||
|
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
|
||||||
|
],
|
||||||
|
model="test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result_msg = UserMessage(
|
||||||
|
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
|
||||||
|
)
|
||||||
|
results = adapter.convert_message(result_msg)
|
||||||
|
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||||
|
assert results[0].success is False
|
||||||
|
assert isinstance(results[1], StreamFinishStep)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_result_list_content():
|
||||||
|
adapter = _adapter()
|
||||||
|
adapter.convert_message(
|
||||||
|
AssistantMessage(
|
||||||
|
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||||
|
model="test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result_msg = UserMessage(
|
||||||
|
content=[
|
||||||
|
ToolResultBlock(
|
||||||
|
tool_use_id="t1",
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "line1"},
|
||||||
|
{"type": "text", "text": "line2"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
results = adapter.convert_message(result_msg)
|
||||||
|
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||||
|
assert results[0].output == "line1line2"
|
||||||
|
assert isinstance(results[1], StreamFinishStep)
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_user_message_ignored():
|
||||||
|
"""A plain string UserMessage (not tool results) produces no output."""
|
||||||
|
adapter = _adapter()
|
||||||
|
results = adapter.convert_message(UserMessage(content="hello"))
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
|
||||||
|
# -- ResultMessage -----------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_success_emits_finish_step_and_finish():
|
||||||
|
adapter = _adapter()
|
||||||
|
# Start some text first (opens step)
|
||||||
|
adapter.convert_message(
|
||||||
|
AssistantMessage(content=[TextBlock(text="done")], model="test")
|
||||||
|
)
|
||||||
|
msg = ResultMessage(
|
||||||
|
subtype="success",
|
||||||
|
duration_ms=100,
|
||||||
|
duration_api_ms=50,
|
||||||
|
is_error=False,
|
||||||
|
num_turns=1,
|
||||||
|
session_id="s1",
|
||||||
|
)
|
||||||
|
results = adapter.convert_message(msg)
|
||||||
|
# TextEnd + FinishStep + StreamFinish
|
||||||
|
assert len(results) == 3
|
||||||
|
assert isinstance(results[0], StreamTextEnd)
|
||||||
|
assert isinstance(results[1], StreamFinishStep)
|
||||||
|
assert isinstance(results[2], StreamFinish)
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_error_emits_error_and_finish():
|
||||||
|
adapter = _adapter()
|
||||||
|
msg = ResultMessage(
|
||||||
|
subtype="error",
|
||||||
|
duration_ms=100,
|
||||||
|
duration_api_ms=50,
|
||||||
|
is_error=True,
|
||||||
|
num_turns=0,
|
||||||
|
session_id="s1",
|
||||||
|
result="API rate limited",
|
||||||
|
)
|
||||||
|
results = adapter.convert_message(msg)
|
||||||
|
# No step was open, so no FinishStep — just Error + Finish
|
||||||
|
assert len(results) == 2
|
||||||
|
assert isinstance(results[0], StreamError)
|
||||||
|
assert "API rate limited" in results[0].errorText
|
||||||
|
assert isinstance(results[1], StreamFinish)
|
||||||
|
|
||||||
|
|
||||||
|
# -- Text after tools (new block ID) ----------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_after_tool_gets_new_block_id():
|
||||||
|
adapter = _adapter()
|
||||||
|
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
|
||||||
|
adapter.convert_message(
|
||||||
|
AssistantMessage(content=[TextBlock(text="before")], model="test")
|
||||||
|
)
|
||||||
|
adapter.convert_message(
|
||||||
|
AssistantMessage(
|
||||||
|
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||||
|
model="test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Send tool result (closes step)
|
||||||
|
adapter.convert_message(
|
||||||
|
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
|
||||||
|
)
|
||||||
|
results = adapter.convert_message(
|
||||||
|
AssistantMessage(content=[TextBlock(text="after")], model="test")
|
||||||
|
)
|
||||||
|
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
|
||||||
|
assert len(results) == 3
|
||||||
|
assert isinstance(results[0], StreamStartStep)
|
||||||
|
assert isinstance(results[1], StreamTextStart)
|
||||||
|
assert isinstance(results[2], StreamTextDelta)
|
||||||
|
assert results[2].delta == "after"
|
||||||
|
|
||||||
|
|
||||||
|
# -- Full conversation flow --------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_conversation_flow():
|
||||||
|
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
|
||||||
|
adapter = _adapter()
|
||||||
|
all_responses: list[StreamBaseResponse] = []
|
||||||
|
|
||||||
|
# 1. Init
|
||||||
|
all_responses.extend(
|
||||||
|
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||||
|
)
|
||||||
|
# 2. Assistant text
|
||||||
|
all_responses.extend(
|
||||||
|
adapter.convert_message(
|
||||||
|
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# 3. Tool use
|
||||||
|
all_responses.extend(
|
||||||
|
adapter.convert_message(
|
||||||
|
AssistantMessage(
|
||||||
|
content=[
|
||||||
|
ToolUseBlock(
|
||||||
|
id="t1",
|
||||||
|
name=f"{MCP_TOOL_PREFIX}find_agent",
|
||||||
|
input={"query": "email"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
model="test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# 4. Tool result
|
||||||
|
all_responses.extend(
|
||||||
|
adapter.convert_message(
|
||||||
|
UserMessage(
|
||||||
|
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# 5. More text
|
||||||
|
all_responses.extend(
|
||||||
|
adapter.convert_message(
|
||||||
|
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# 6. Result
|
||||||
|
all_responses.extend(
|
||||||
|
adapter.convert_message(
|
||||||
|
ResultMessage(
|
||||||
|
subtype="success",
|
||||||
|
duration_ms=500,
|
||||||
|
duration_api_ms=400,
|
||||||
|
is_error=False,
|
||||||
|
num_turns=2,
|
||||||
|
session_id="s1",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
types = [type(r).__name__ for r in all_responses]
|
||||||
|
assert types == [
|
||||||
|
"StreamStart",
|
||||||
|
"StreamStartStep", # step 1: text + tool call
|
||||||
|
"StreamTextStart",
|
||||||
|
"StreamTextDelta", # "Let me search"
|
||||||
|
"StreamTextEnd", # closed before tool
|
||||||
|
"StreamToolInputStart",
|
||||||
|
"StreamToolInputAvailable",
|
||||||
|
"StreamToolOutputAvailable", # tool result
|
||||||
|
"StreamFinishStep", # step 1 closed after tool result
|
||||||
|
"StreamStartStep", # step 2: continuation text
|
||||||
|
"StreamTextStart", # new block after tool
|
||||||
|
"StreamTextDelta", # "I found 2"
|
||||||
|
"StreamTextEnd", # closed by result
|
||||||
|
"StreamFinishStep", # step 2 closed
|
||||||
|
"StreamFinish",
|
||||||
|
]
|
||||||
@@ -0,0 +1,335 @@
|
|||||||
|
"""Security hooks for Claude Agent SDK integration.
|
||||||
|
|
||||||
|
This module provides security hooks that validate tool calls before execution,
|
||||||
|
ensuring multi-user isolation and preventing unauthorized operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Tools that are blocked entirely (CLI/system access).
|
||||||
|
# "Bash" (capital) is the SDK built-in — it's NOT in allowed_tools but blocked
|
||||||
|
# here as defence-in-depth. The agent uses mcp__copilot__bash_exec instead,
|
||||||
|
# which has kernel-level network isolation (unshare --net).
|
||||||
|
BLOCKED_TOOLS = {
|
||||||
|
"Bash",
|
||||||
|
"bash",
|
||||||
|
"shell",
|
||||||
|
"exec",
|
||||||
|
"terminal",
|
||||||
|
"command",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Tools allowed only when their path argument stays within the SDK workspace.
|
||||||
|
# The SDK uses these to handle oversized tool results (writes to tool-results/
|
||||||
|
# files, then reads them back) and for workspace file operations.
|
||||||
|
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
|
||||||
|
|
||||||
|
# Dangerous patterns in tool inputs
|
||||||
|
DANGEROUS_PATTERNS = [
|
||||||
|
r"sudo",
|
||||||
|
r"rm\s+-rf",
|
||||||
|
r"dd\s+if=",
|
||||||
|
r"/etc/passwd",
|
||||||
|
r"/etc/shadow",
|
||||||
|
r"chmod\s+777",
|
||||||
|
r"curl\s+.*\|.*sh",
|
||||||
|
r"wget\s+.*\|.*sh",
|
||||||
|
r"eval\s*\(",
|
||||||
|
r"exec\s*\(",
|
||||||
|
r"__import__",
|
||||||
|
r"os\.system",
|
||||||
|
r"subprocess",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _deny(reason: str) -> dict[str, Any]:
|
||||||
|
"""Return a hook denial response."""
|
||||||
|
return {
|
||||||
|
"hookSpecificOutput": {
|
||||||
|
"hookEventName": "PreToolUse",
|
||||||
|
"permissionDecision": "deny",
|
||||||
|
"permissionDecisionReason": reason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_workspace_path(
|
||||||
|
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Validate that a workspace-scoped tool only accesses allowed paths.
|
||||||
|
|
||||||
|
Allowed directories:
|
||||||
|
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||||
|
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
|
||||||
|
"""
|
||||||
|
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||||
|
if not path:
|
||||||
|
# Glob/Grep without a path default to cwd which is already sandboxed
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
|
||||||
|
# naturally uses relative paths like "test.txt" instead of absolute ones).
|
||||||
|
# Tilde paths (~/) are home-dir references, not relative — expand first.
|
||||||
|
if path.startswith("~"):
|
||||||
|
resolved = os.path.realpath(os.path.expanduser(path))
|
||||||
|
elif not os.path.isabs(path) and sdk_cwd:
|
||||||
|
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||||
|
else:
|
||||||
|
resolved = os.path.realpath(path)
|
||||||
|
|
||||||
|
# Allow access within the SDK working directory
|
||||||
|
if sdk_cwd:
|
||||||
|
norm_cwd = os.path.realpath(sdk_cwd)
|
||||||
|
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
|
||||||
|
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||||
|
tool_results_seg = os.sep + "tool-results" + os.sep
|
||||||
|
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
|
||||||
|
)
|
||||||
|
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||||
|
return _deny(
|
||||||
|
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||||
|
f"directory.{workspace_hint} "
|
||||||
|
"This is enforced by the platform and cannot be bypassed."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_tool_access(
|
||||||
|
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Validate that a tool call is allowed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Empty dict to allow, or dict with hookSpecificOutput to deny
|
||||||
|
"""
|
||||||
|
# Block forbidden tools
|
||||||
|
if tool_name in BLOCKED_TOOLS:
|
||||||
|
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||||
|
return _deny(
|
||||||
|
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
||||||
|
"This is enforced by the platform and cannot be bypassed. "
|
||||||
|
"Use the CoPilot-specific MCP tools instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Workspace-scoped tools: allowed only within the SDK workspace directory
|
||||||
|
if tool_name in WORKSPACE_SCOPED_TOOLS:
|
||||||
|
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
|
||||||
|
|
||||||
|
# Check for dangerous patterns in tool input
|
||||||
|
# Use json.dumps for predictable format (str() produces Python repr)
|
||||||
|
input_str = json.dumps(tool_input) if tool_input else ""
|
||||||
|
|
||||||
|
for pattern in DANGEROUS_PATTERNS:
|
||||||
|
if re.search(pattern, input_str, re.IGNORECASE):
|
||||||
|
logger.warning(
|
||||||
|
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
||||||
|
)
|
||||||
|
return _deny(
|
||||||
|
"[SECURITY] Input contains a blocked pattern. "
|
||||||
|
"This is enforced by the platform and cannot be bypassed."
|
||||||
|
)
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_user_isolation(
|
||||||
|
tool_name: str, tool_input: dict[str, Any], user_id: str | None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Validate that tool calls respect user isolation."""
|
||||||
|
# For workspace file tools, ensure path doesn't escape
|
||||||
|
if "workspace" in tool_name.lower():
|
||||||
|
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||||
|
if path:
|
||||||
|
# Check for path traversal
|
||||||
|
if ".." in path or path.startswith("/"):
|
||||||
|
logger.warning(
|
||||||
|
f"Blocked path traversal attempt: {path} by user {user_id}"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"hookSpecificOutput": {
|
||||||
|
"hookEventName": "PreToolUse",
|
||||||
|
"permissionDecision": "deny",
|
||||||
|
"permissionDecisionReason": "Path traversal not allowed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def create_security_hooks(
|
||||||
|
user_id: str | None,
|
||||||
|
sdk_cwd: str | None = None,
|
||||||
|
max_subtasks: int = 3,
|
||||||
|
on_stop: Callable[[str, str], None] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create the security hooks configuration for Claude Agent SDK.
|
||||||
|
|
||||||
|
Includes security validation and observability hooks:
|
||||||
|
- PreToolUse: Security validation before tool execution
|
||||||
|
- PostToolUse: Log successful tool executions
|
||||||
|
- PostToolUseFailure: Log and handle failed tool executions
|
||||||
|
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||||
|
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Current user ID for isolation validation
|
||||||
|
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||||
|
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
|
||||||
|
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
|
||||||
|
the SDK finishes processing — used to read the JSONL transcript
|
||||||
|
before the CLI process exits.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hooks configuration dict for ClaudeAgentOptions
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from claude_agent_sdk import HookMatcher
|
||||||
|
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||||
|
|
||||||
|
# Per-session counter for Task sub-agent spawns
|
||||||
|
task_spawn_count = 0
|
||||||
|
|
||||||
|
async def pre_tool_use_hook(
|
||||||
|
input_data: HookInput,
|
||||||
|
tool_use_id: str | None,
|
||||||
|
context: HookContext,
|
||||||
|
) -> SyncHookJSONOutput:
|
||||||
|
"""Combined pre-tool-use validation hook."""
|
||||||
|
nonlocal task_spawn_count
|
||||||
|
_ = context # unused but required by signature
|
||||||
|
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||||
|
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||||
|
|
||||||
|
# Rate-limit Task (sub-agent) spawns per session
|
||||||
|
if tool_name == "Task":
|
||||||
|
task_spawn_count += 1
|
||||||
|
if task_spawn_count > max_subtasks:
|
||||||
|
logger.warning(
|
||||||
|
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
|
||||||
|
)
|
||||||
|
return cast(
|
||||||
|
SyncHookJSONOutput,
|
||||||
|
_deny(
|
||||||
|
f"Maximum {max_subtasks} sub-tasks per session. "
|
||||||
|
"Please continue in the main conversation."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Strip MCP prefix for consistent validation
|
||||||
|
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
|
||||||
|
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
|
||||||
|
|
||||||
|
# Only block non-CoPilot tools; our MCP-registered tools
|
||||||
|
# (including Read for oversized results) are already sandboxed.
|
||||||
|
if not is_copilot_tool:
|
||||||
|
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
|
||||||
|
if result:
|
||||||
|
return cast(SyncHookJSONOutput, result)
|
||||||
|
|
||||||
|
# Validate user isolation
|
||||||
|
result = _validate_user_isolation(clean_name, tool_input, user_id)
|
||||||
|
if result:
|
||||||
|
return cast(SyncHookJSONOutput, result)
|
||||||
|
|
||||||
|
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||||
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
|
async def post_tool_use_hook(
|
||||||
|
input_data: HookInput,
|
||||||
|
tool_use_id: str | None,
|
||||||
|
context: HookContext,
|
||||||
|
) -> SyncHookJSONOutput:
|
||||||
|
"""Log successful tool executions for observability."""
|
||||||
|
_ = context
|
||||||
|
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||||
|
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
||||||
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
|
async def post_tool_failure_hook(
|
||||||
|
input_data: HookInput,
|
||||||
|
tool_use_id: str | None,
|
||||||
|
context: HookContext,
|
||||||
|
) -> SyncHookJSONOutput:
|
||||||
|
"""Log failed tool executions for debugging."""
|
||||||
|
_ = context
|
||||||
|
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||||
|
error = input_data.get("error", "Unknown error")
|
||||||
|
logger.warning(
|
||||||
|
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||||
|
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||||
|
)
|
||||||
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
|
async def pre_compact_hook(
|
||||||
|
input_data: HookInput,
|
||||||
|
tool_use_id: str | None,
|
||||||
|
context: HookContext,
|
||||||
|
) -> SyncHookJSONOutput:
|
||||||
|
"""Log when SDK triggers context compaction.
|
||||||
|
|
||||||
|
The SDK automatically compacts conversation history when it grows too large.
|
||||||
|
This hook provides visibility into when compaction happens.
|
||||||
|
"""
|
||||||
|
_ = context, tool_use_id
|
||||||
|
trigger = input_data.get("trigger", "auto")
|
||||||
|
logger.info(
|
||||||
|
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||||
|
)
|
||||||
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
|
# --- Stop hook: capture transcript path for stateless resume ---
|
||||||
|
async def stop_hook(
|
||||||
|
input_data: HookInput,
|
||||||
|
tool_use_id: str | None,
|
||||||
|
context: HookContext,
|
||||||
|
) -> SyncHookJSONOutput:
|
||||||
|
"""Capture transcript path when SDK finishes processing.
|
||||||
|
|
||||||
|
The Stop hook fires while the CLI process is still alive, giving us
|
||||||
|
a reliable window to read the JSONL transcript before SIGTERM.
|
||||||
|
"""
|
||||||
|
_ = context, tool_use_id
|
||||||
|
transcript_path = cast(str, input_data.get("transcript_path", ""))
|
||||||
|
sdk_session_id = cast(str, input_data.get("session_id", ""))
|
||||||
|
|
||||||
|
if transcript_path and on_stop:
|
||||||
|
logger.info(
|
||||||
|
f"[SDK] Stop hook: transcript_path={transcript_path}, "
|
||||||
|
f"sdk_session_id={sdk_session_id[:12]}..."
|
||||||
|
)
|
||||||
|
on_stop(transcript_path, sdk_session_id)
|
||||||
|
|
||||||
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
|
hooks: dict[str, Any] = {
|
||||||
|
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
||||||
|
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
||||||
|
"PostToolUseFailure": [
|
||||||
|
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
|
||||||
|
],
|
||||||
|
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||||
|
}
|
||||||
|
|
||||||
|
if on_stop is not None:
|
||||||
|
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
|
||||||
|
|
||||||
|
return hooks
|
||||||
|
except ImportError:
|
||||||
|
# Fallback for when SDK isn't available - return empty hooks
|
||||||
|
logger.warning("claude-agent-sdk not available, security hooks disabled")
|
||||||
|
return {}
|
||||||
@@ -0,0 +1,165 @@
|
|||||||
|
"""Unit tests for SDK security hooks."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
||||||
|
|
||||||
|
SDK_CWD = "/tmp/copilot-abc123"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_denied(result: dict) -> bool:
|
||||||
|
hook = result.get("hookSpecificOutput", {})
|
||||||
|
return hook.get("permissionDecision") == "deny"
|
||||||
|
|
||||||
|
|
||||||
|
# -- Blocked tools -----------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_blocked_tools_denied():
|
||||||
|
for tool in ("bash", "shell", "exec", "terminal", "command"):
|
||||||
|
result = _validate_tool_access(tool, {})
|
||||||
|
assert _is_denied(result), f"{tool} should be blocked"
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_tool_allowed():
|
||||||
|
result = _validate_tool_access("SomeCustomTool", {})
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
# -- Workspace-scoped tools --------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_within_workspace_allowed():
|
||||||
|
result = _validate_tool_access(
|
||||||
|
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
|
||||||
|
)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_within_workspace_allowed():
|
||||||
|
result = _validate_tool_access(
|
||||||
|
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
|
||||||
|
)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_within_workspace_allowed():
|
||||||
|
result = _validate_tool_access(
|
||||||
|
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
|
||||||
|
)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_glob_within_workspace_allowed():
|
||||||
|
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_grep_within_workspace_allowed():
|
||||||
|
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_outside_workspace_denied():
|
||||||
|
result = _validate_tool_access(
|
||||||
|
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
|
||||||
|
)
|
||||||
|
assert _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_outside_workspace_denied():
|
||||||
|
result = _validate_tool_access(
|
||||||
|
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
|
||||||
|
)
|
||||||
|
assert _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
|
def test_traversal_attack_denied():
|
||||||
|
result = _validate_tool_access(
|
||||||
|
"Read",
|
||||||
|
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
|
||||||
|
sdk_cwd=SDK_CWD,
|
||||||
|
)
|
||||||
|
assert _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_path_allowed():
|
||||||
|
"""Glob/Grep without a path argument defaults to cwd — should pass."""
|
||||||
|
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_no_cwd_denies_absolute():
|
||||||
|
"""If no sdk_cwd is set, absolute paths are denied."""
|
||||||
|
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
|
||||||
|
assert _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
|
# -- Tool-results directory --------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_tool_results_allowed():
|
||||||
|
home = os.path.expanduser("~")
|
||||||
|
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||||
|
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_claude_projects_without_tool_results_denied():
|
||||||
|
home = os.path.expanduser("~")
|
||||||
|
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||||
|
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||||
|
assert _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
|
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_bash_builtin_always_blocked():
|
||||||
|
"""SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
|
||||||
|
result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
|
||||||
|
assert _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
|
# -- Dangerous patterns ------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_dangerous_pattern_blocked():
|
||||||
|
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
|
||||||
|
assert _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subprocess_pattern_blocked():
|
||||||
|
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
|
||||||
|
assert _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
|
# -- User isolation ----------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_workspace_path_traversal_blocked():
|
||||||
|
result = _validate_user_isolation(
|
||||||
|
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
|
||||||
|
)
|
||||||
|
assert _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
|
def test_workspace_absolute_path_blocked():
|
||||||
|
result = _validate_user_isolation(
|
||||||
|
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
|
||||||
|
)
|
||||||
|
assert _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
|
def test_workspace_normal_path_allowed():
|
||||||
|
result = _validate_user_isolation(
|
||||||
|
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
|
||||||
|
)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_workspace_tool_passes_isolation():
|
||||||
|
result = _validate_user_isolation(
|
||||||
|
"find_agent", {"query": "email"}, user_id="user-1"
|
||||||
|
)
|
||||||
|
assert result == {}
|
||||||
@@ -0,0 +1,751 @@
|
|||||||
|
"""Claude Agent SDK service layer for CoPilot chat completions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
|
from .. import stream_registry
|
||||||
|
from ..config import ChatConfig
|
||||||
|
from ..model import (
|
||||||
|
ChatMessage,
|
||||||
|
ChatSession,
|
||||||
|
get_chat_session,
|
||||||
|
update_session_title,
|
||||||
|
upsert_chat_session,
|
||||||
|
)
|
||||||
|
from ..response_model import (
|
||||||
|
StreamBaseResponse,
|
||||||
|
StreamError,
|
||||||
|
StreamFinish,
|
||||||
|
StreamStart,
|
||||||
|
StreamTextDelta,
|
||||||
|
StreamToolInputAvailable,
|
||||||
|
StreamToolOutputAvailable,
|
||||||
|
)
|
||||||
|
from ..service import (
|
||||||
|
_build_system_prompt,
|
||||||
|
_execute_long_running_tool_with_streaming,
|
||||||
|
_generate_session_title,
|
||||||
|
)
|
||||||
|
from ..tools.models import OperationPendingResponse, OperationStartedResponse
|
||||||
|
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||||
|
from ..tracking import track_user_message
|
||||||
|
from .response_adapter import SDKResponseAdapter
|
||||||
|
from .security_hooks import create_security_hooks
|
||||||
|
from .tool_adapter import (
|
||||||
|
COPILOT_TOOL_NAMES,
|
||||||
|
LongRunningCallback,
|
||||||
|
create_copilot_mcp_server,
|
||||||
|
set_execution_context,
|
||||||
|
)
|
||||||
|
from .transcript import (
|
||||||
|
download_transcript,
|
||||||
|
read_transcript_file,
|
||||||
|
upload_transcript,
|
||||||
|
validate_transcript,
|
||||||
|
write_transcript_to_tempfile,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
config = ChatConfig()
|
||||||
|
|
||||||
|
# Set to hold background tasks to prevent garbage collection
|
||||||
|
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CapturedTranscript:
|
||||||
|
"""Info captured by the SDK Stop hook for stateless --resume."""
|
||||||
|
|
||||||
|
path: str = ""
|
||||||
|
sdk_session_id: str = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available(self) -> bool:
|
||||||
|
return bool(self.path)
|
||||||
|
|
||||||
|
|
||||||
|
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
||||||
|
|
||||||
|
# Appended to the system prompt to inform the agent about available tools.
|
||||||
|
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
|
||||||
|
# which has kernel-level network isolation (unshare --net).
|
||||||
|
_SDK_TOOL_SUPPLEMENT = """
|
||||||
|
|
||||||
|
## Tool notes
|
||||||
|
|
||||||
|
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||||
|
for shell commands — it runs in a network-isolated sandbox.
|
||||||
|
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
|
||||||
|
same working directory. Files created by one are readable by the other.
|
||||||
|
These files are **ephemeral** — they exist only for the current session.
|
||||||
|
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
|
||||||
|
for files that should persist across sessions (stored in cloud storage).
|
||||||
|
- Long-running tools (create_agent, edit_agent, etc.) are handled
|
||||||
|
asynchronously. You will receive an immediate response; the actual result
|
||||||
|
is delivered to the user via a background stream.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
||||||
|
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
|
||||||
|
|
||||||
|
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
|
||||||
|
existing background infrastructure: stream_registry (Redis Streams),
|
||||||
|
database persistence, and SSE reconnection. This means results survive
|
||||||
|
page refreshes / pod restarts, and the frontend shows the proper loading
|
||||||
|
widget with progress updates.
|
||||||
|
|
||||||
|
The returned callback matches the ``LongRunningCallback`` signature:
|
||||||
|
``(tool_name, args, session) -> MCP response dict``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _callback(
|
||||||
|
tool_name: str, args: dict[str, Any], session: ChatSession
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
operation_id = str(uuid.uuid4())
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
# --- Build user-friendly messages (matches non-SDK service) ---
|
||||||
|
if tool_name == "create_agent":
|
||||||
|
desc = args.get("description", "")
|
||||||
|
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
|
||||||
|
pending_msg = (
|
||||||
|
f"Creating your agent: {desc_preview}"
|
||||||
|
if desc_preview
|
||||||
|
else "Creating agent... This may take a few minutes."
|
||||||
|
)
|
||||||
|
started_msg = (
|
||||||
|
"Agent creation started. You can close this tab - "
|
||||||
|
"check your library in a few minutes."
|
||||||
|
)
|
||||||
|
elif tool_name == "edit_agent":
|
||||||
|
changes = args.get("changes", "")
|
||||||
|
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
|
||||||
|
pending_msg = (
|
||||||
|
f"Editing agent: {changes_preview}"
|
||||||
|
if changes_preview
|
||||||
|
else "Editing agent... This may take a few minutes."
|
||||||
|
)
|
||||||
|
started_msg = (
|
||||||
|
"Agent edit started. You can close this tab - "
|
||||||
|
"check your library in a few minutes."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pending_msg = f"Running {tool_name}... This may take a few minutes."
|
||||||
|
started_msg = (
|
||||||
|
f"{tool_name} started. You can close this tab - "
|
||||||
|
"check back in a few minutes."
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Register task in Redis for SSE reconnection ---
|
||||||
|
await stream_registry.create_task(
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
operation_id=operation_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Save OperationPendingResponse to chat history ---
|
||||||
|
pending_message = ChatMessage(
|
||||||
|
role="tool",
|
||||||
|
content=OperationPendingResponse(
|
||||||
|
message=pending_msg,
|
||||||
|
operation_id=operation_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
).model_dump_json(),
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
)
|
||||||
|
session.messages.append(pending_message)
|
||||||
|
await upsert_chat_session(session)
|
||||||
|
|
||||||
|
# --- Spawn background task (reuses non-SDK infrastructure) ---
|
||||||
|
bg_task = asyncio.create_task(
|
||||||
|
_execute_long_running_tool_with_streaming(
|
||||||
|
tool_name=tool_name,
|
||||||
|
parameters=args,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
operation_id=operation_id,
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_background_tasks.add(bg_task)
|
||||||
|
bg_task.add_done_callback(_background_tasks.discard)
|
||||||
|
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[SDK] Long-running tool {tool_name} delegated to background "
|
||||||
|
f"(operation_id={operation_id}, task_id={task_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Return OperationStartedResponse as MCP tool result ---
|
||||||
|
# This flows through SDK → response adapter → frontend, triggering
|
||||||
|
# the loading widget with SSE reconnection support.
|
||||||
|
started_json = OperationStartedResponse(
|
||||||
|
message=started_msg,
|
||||||
|
operation_id=operation_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
task_id=task_id,
|
||||||
|
).model_dump_json()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": started_json}],
|
||||||
|
"isError": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
return _callback
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_sdk_model() -> str | None:
|
||||||
|
"""Resolve the model name for the Claude Agent SDK CLI.
|
||||||
|
|
||||||
|
Uses ``config.claude_agent_model`` if set, otherwise derives from
|
||||||
|
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
|
||||||
|
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
|
||||||
|
"""
|
||||||
|
if config.claude_agent_model:
|
||||||
|
return config.claude_agent_model
|
||||||
|
model = config.model
|
||||||
|
if "/" in model:
|
||||||
|
return model.split("/", 1)[1]
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _build_sdk_env() -> dict[str, str]:
|
||||||
|
"""Build env vars for the SDK CLI process.
|
||||||
|
|
||||||
|
Routes API calls through OpenRouter (or a custom base_url) using
|
||||||
|
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
|
||||||
|
This gives per-call token and cost tracking on the OpenRouter dashboard.
|
||||||
|
|
||||||
|
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
|
||||||
|
token are both present — otherwise returns an empty dict so the SDK
|
||||||
|
falls back to its default credentials.
|
||||||
|
"""
|
||||||
|
env: dict[str, str] = {}
|
||||||
|
if config.api_key and config.base_url:
|
||||||
|
# Strip /v1 suffix — SDK expects the base URL without a version path
|
||||||
|
base = config.base_url.rstrip("/")
|
||||||
|
if base.endswith("/v1"):
|
||||||
|
base = base[:-3]
|
||||||
|
if not base or not base.startswith("http"):
|
||||||
|
# Invalid base_url — don't override SDK defaults
|
||||||
|
return env
|
||||||
|
env["ANTHROPIC_BASE_URL"] = base
|
||||||
|
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
|
||||||
|
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
|
||||||
|
env["ANTHROPIC_API_KEY"] = ""
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sdk_cwd(session_id: str) -> str:
|
||||||
|
"""Create a safe, session-specific working directory path.
|
||||||
|
|
||||||
|
Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
|
||||||
|
(single source of truth for path sanitization) and adds a defence-in-depth
|
||||||
|
assertion.
|
||||||
|
"""
|
||||||
|
cwd = make_session_path(session_id)
|
||||||
|
# Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer
|
||||||
|
cwd = os.path.normpath(cwd)
|
||||||
|
if not cwd.startswith(_SDK_CWD_PREFIX):
|
||||||
|
raise ValueError(f"SDK cwd escaped prefix: {cwd}")
|
||||||
|
return cwd
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||||
|
"""Remove SDK tool-result files for a specific session working directory.
|
||||||
|
|
||||||
|
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
|
||||||
|
We clean only the specific cwd's results to avoid race conditions between
|
||||||
|
concurrent sessions.
|
||||||
|
|
||||||
|
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
|
||||||
|
"""
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
# Validate cwd is under the expected prefix
|
||||||
|
normalized = os.path.normpath(cwd)
|
||||||
|
if not normalized.startswith(_SDK_CWD_PREFIX):
|
||||||
|
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# SDK encodes the cwd path by replacing '/' with '-'
|
||||||
|
encoded_cwd = normalized.replace("/", "-")
|
||||||
|
|
||||||
|
# Construct the project directory path (known-safe home expansion)
|
||||||
|
claude_projects = os.path.expanduser("~/.claude/projects")
|
||||||
|
project_dir = os.path.join(claude_projects, encoded_cwd)
|
||||||
|
|
||||||
|
# Security check 3: Validate project_dir is under ~/.claude/projects
|
||||||
|
project_dir = os.path.normpath(project_dir)
|
||||||
|
if not project_dir.startswith(claude_projects):
|
||||||
|
logger.warning(
|
||||||
|
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
results_dir = os.path.join(project_dir, "tool-results")
|
||||||
|
if os.path.isdir(results_dir):
|
||||||
|
for filename in os.listdir(results_dir):
|
||||||
|
file_path = os.path.join(results_dir, filename)
|
||||||
|
try:
|
||||||
|
if os.path.isfile(file_path):
|
||||||
|
os.remove(file_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Also clean up the temp cwd directory itself
|
||||||
|
try:
|
||||||
|
shutil.rmtree(normalized, ignore_errors=True)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _compress_conversation_history(
|
||||||
|
session: ChatSession,
|
||||||
|
) -> list[ChatMessage]:
|
||||||
|
"""Compress prior conversation messages if they exceed the token threshold.
|
||||||
|
|
||||||
|
Uses the shared compress_context() from prompt.py which supports:
|
||||||
|
- LLM summarization of old messages (keeps recent ones intact)
|
||||||
|
- Progressive content truncation as fallback
|
||||||
|
- Middle-out deletion as last resort
|
||||||
|
|
||||||
|
Returns the compressed prior messages (everything except the current message).
|
||||||
|
"""
|
||||||
|
prior = session.messages[:-1]
|
||||||
|
if len(prior) < 2:
|
||||||
|
return prior
|
||||||
|
|
||||||
|
from backend.util.prompt import compress_context
|
||||||
|
|
||||||
|
# Convert ChatMessages to dicts for compress_context
|
||||||
|
messages_dict = []
|
||||||
|
for msg in prior:
|
||||||
|
msg_dict: dict[str, Any] = {"role": msg.role}
|
||||||
|
if msg.content:
|
||||||
|
msg_dict["content"] = msg.content
|
||||||
|
if msg.tool_calls:
|
||||||
|
msg_dict["tool_calls"] = msg.tool_calls
|
||||||
|
if msg.tool_call_id:
|
||||||
|
msg_dict["tool_call_id"] = msg.tool_call_id
|
||||||
|
messages_dict.append(msg_dict)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import openai
|
||||||
|
|
||||||
|
async with openai.AsyncOpenAI(
|
||||||
|
api_key=config.api_key, base_url=config.base_url, timeout=30.0
|
||||||
|
) as client:
|
||||||
|
result = await compress_context(
|
||||||
|
messages=messages_dict,
|
||||||
|
model=config.model,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
|
||||||
|
# Fall back to truncation-only (no LLM summarization)
|
||||||
|
result = await compress_context(
|
||||||
|
messages=messages_dict,
|
||||||
|
model=config.model,
|
||||||
|
client=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.was_compacted:
|
||||||
|
logger.info(
|
||||||
|
f"[SDK] Context compacted: {result.original_token_count} -> "
|
||||||
|
f"{result.token_count} tokens "
|
||||||
|
f"({result.messages_summarized} summarized, "
|
||||||
|
f"{result.messages_dropped} dropped)"
|
||||||
|
)
|
||||||
|
# Convert compressed dicts back to ChatMessages
|
||||||
|
return [
|
||||||
|
ChatMessage(
|
||||||
|
role=m["role"],
|
||||||
|
content=m.get("content"),
|
||||||
|
tool_calls=m.get("tool_calls"),
|
||||||
|
tool_call_id=m.get("tool_call_id"),
|
||||||
|
)
|
||||||
|
for m in result.messages
|
||||||
|
]
|
||||||
|
|
||||||
|
return prior
|
||||||
|
|
||||||
|
|
||||||
|
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
||||||
|
"""Format conversation messages into a context prefix for the user message.
|
||||||
|
|
||||||
|
Returns a string like:
|
||||||
|
<conversation_history>
|
||||||
|
User: hello
|
||||||
|
You responded: Hi! How can I help?
|
||||||
|
</conversation_history>
|
||||||
|
|
||||||
|
Returns None if there are no messages to format.
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
lines: list[str] = []
|
||||||
|
for msg in messages:
|
||||||
|
if not msg.content:
|
||||||
|
continue
|
||||||
|
if msg.role == "user":
|
||||||
|
lines.append(f"User: {msg.content}")
|
||||||
|
elif msg.role == "assistant":
|
||||||
|
lines.append(f"You responded: {msg.content}")
|
||||||
|
# Skip tool messages — they're internal details
|
||||||
|
|
||||||
|
if not lines:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_chat_completion_sdk(
|
||||||
|
session_id: str,
|
||||||
|
message: str | None = None,
|
||||||
|
tool_call_response: str | None = None, # noqa: ARG001
|
||||||
|
is_user_message: bool = True,
|
||||||
|
user_id: str | None = None,
|
||||||
|
retry_count: int = 0, # noqa: ARG001
|
||||||
|
session: ChatSession | None = None,
|
||||||
|
context: dict[str, str] | None = None, # noqa: ARG001
|
||||||
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
|
"""Stream chat completion using Claude Agent SDK.
|
||||||
|
|
||||||
|
Drop-in replacement for stream_chat_completion with improved reliability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
raise NotFoundError(
|
||||||
|
f"Session {session_id} not found. Please create a new session first."
|
||||||
|
)
|
||||||
|
|
||||||
|
if message:
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role="user" if is_user_message else "assistant", content=message
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if is_user_message:
|
||||||
|
track_user_message(
|
||||||
|
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||||
|
)
|
||||||
|
|
||||||
|
session = await upsert_chat_session(session)
|
||||||
|
|
||||||
|
# Generate title for new sessions (first user message)
|
||||||
|
if is_user_message and not session.title:
|
||||||
|
user_messages = [m for m in session.messages if m.role == "user"]
|
||||||
|
if len(user_messages) == 1:
|
||||||
|
first_message = user_messages[0].content or message or ""
|
||||||
|
if first_message:
|
||||||
|
task = asyncio.create_task(
|
||||||
|
_update_title_async(session_id, first_message, user_id)
|
||||||
|
)
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
|
# Build system prompt (reuses non-SDK path with Langfuse support)
|
||||||
|
has_history = len(session.messages) > 1
|
||||||
|
system_prompt, _ = await _build_system_prompt(
|
||||||
|
user_id, has_conversation_history=has_history
|
||||||
|
)
|
||||||
|
system_prompt += _SDK_TOOL_SUPPLEMENT
|
||||||
|
message_id = str(uuid.uuid4())
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||||
|
|
||||||
|
stream_completed = False
|
||||||
|
# Initialise sdk_cwd before the try so the finally can reference it
|
||||||
|
# even if _make_sdk_cwd raises (in that case it stays as "").
|
||||||
|
sdk_cwd = ""
|
||||||
|
use_resume = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use a session-specific temp dir to avoid cleanup race conditions
|
||||||
|
# between concurrent sessions.
|
||||||
|
sdk_cwd = _make_sdk_cwd(session_id)
|
||||||
|
os.makedirs(sdk_cwd, exist_ok=True)
|
||||||
|
|
||||||
|
set_execution_context(
|
||||||
|
user_id,
|
||||||
|
session,
|
||||||
|
long_running_callback=_build_long_running_callback(user_id),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||||
|
|
||||||
|
# Fail fast when no API credentials are available at all
|
||||||
|
sdk_env = _build_sdk_env()
|
||||||
|
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
|
||||||
|
raise RuntimeError(
|
||||||
|
"No API key configured. Set OPEN_ROUTER_API_KEY "
|
||||||
|
"(or CHAT_API_KEY) for OpenRouter routing, "
|
||||||
|
"or ANTHROPIC_API_KEY for direct Anthropic access."
|
||||||
|
)
|
||||||
|
|
||||||
|
mcp_server = create_copilot_mcp_server()
|
||||||
|
|
||||||
|
sdk_model = _resolve_sdk_model()
|
||||||
|
|
||||||
|
# --- Transcript capture via Stop hook ---
|
||||||
|
captured_transcript = CapturedTranscript()
|
||||||
|
|
||||||
|
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
||||||
|
captured_transcript.path = transcript_path
|
||||||
|
captured_transcript.sdk_session_id = sdk_session_id
|
||||||
|
|
||||||
|
security_hooks = create_security_hooks(
|
||||||
|
user_id,
|
||||||
|
sdk_cwd=sdk_cwd,
|
||||||
|
max_subtasks=config.claude_agent_max_subtasks,
|
||||||
|
on_stop=_on_stop if config.claude_agent_use_resume else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Resume strategy: download transcript from bucket ---
|
||||||
|
resume_file: str | None = None
|
||||||
|
use_resume = False
|
||||||
|
|
||||||
|
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||||
|
transcript_content = await download_transcript(user_id, session_id)
|
||||||
|
if transcript_content and validate_transcript(transcript_content):
|
||||||
|
resume_file = write_transcript_to_tempfile(
|
||||||
|
transcript_content, session_id, sdk_cwd
|
||||||
|
)
|
||||||
|
if resume_file:
|
||||||
|
use_resume = True
|
||||||
|
logger.info(
|
||||||
|
f"[SDK] Using --resume with transcript "
|
||||||
|
f"({len(transcript_content)} bytes)"
|
||||||
|
)
|
||||||
|
|
||||||
|
sdk_options_kwargs: dict[str, Any] = {
|
||||||
|
"system_prompt": system_prompt,
|
||||||
|
"mcp_servers": {"copilot": mcp_server},
|
||||||
|
"allowed_tools": COPILOT_TOOL_NAMES,
|
||||||
|
"disallowed_tools": ["Bash"],
|
||||||
|
"hooks": security_hooks,
|
||||||
|
"cwd": sdk_cwd,
|
||||||
|
"max_buffer_size": config.claude_agent_max_buffer_size,
|
||||||
|
}
|
||||||
|
if sdk_env:
|
||||||
|
sdk_options_kwargs["model"] = sdk_model
|
||||||
|
sdk_options_kwargs["env"] = sdk_env
|
||||||
|
if use_resume and resume_file:
|
||||||
|
sdk_options_kwargs["resume"] = resume_file
|
||||||
|
|
||||||
|
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
adapter = SDKResponseAdapter(message_id=message_id)
|
||||||
|
adapter.set_task_id(task_id)
|
||||||
|
|
||||||
|
async with ClaudeSDKClient(options=options) as client:
|
||||||
|
current_message = message or ""
|
||||||
|
if not current_message and session.messages:
|
||||||
|
last_user = [m for m in session.messages if m.role == "user"]
|
||||||
|
if last_user:
|
||||||
|
current_message = last_user[-1].content or ""
|
||||||
|
|
||||||
|
if not current_message.strip():
|
||||||
|
yield StreamError(
|
||||||
|
errorText="Message cannot be empty.",
|
||||||
|
code="empty_prompt",
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build query: with --resume the CLI already has full
|
||||||
|
# context, so we only send the new message. Without
|
||||||
|
# resume, compress history into a context prefix.
|
||||||
|
query_message = current_message
|
||||||
|
if not use_resume and len(session.messages) > 1:
|
||||||
|
logger.warning(
|
||||||
|
f"[SDK] Using compression fallback for session "
|
||||||
|
f"{session_id} ({len(session.messages)} messages) — "
|
||||||
|
f"no transcript available for --resume"
|
||||||
|
)
|
||||||
|
compressed = await _compress_conversation_history(session)
|
||||||
|
history_context = _format_conversation_context(compressed)
|
||||||
|
if history_context:
|
||||||
|
query_message = (
|
||||||
|
f"{history_context}\n\n"
|
||||||
|
f"Now, the user says:\n{current_message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[SDK] Sending query ({len(session.messages)} msgs in session)"
|
||||||
|
)
|
||||||
|
logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
|
||||||
|
await client.query(query_message, session_id=session_id)
|
||||||
|
|
||||||
|
assistant_response = ChatMessage(role="assistant", content="")
|
||||||
|
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||||
|
has_appended_assistant = False
|
||||||
|
has_tool_results = False
|
||||||
|
|
||||||
|
async for sdk_msg in client.receive_messages():
|
||||||
|
logger.debug(
|
||||||
|
f"[SDK] Received: {type(sdk_msg).__name__} "
|
||||||
|
f"{getattr(sdk_msg, 'subtype', '')}"
|
||||||
|
)
|
||||||
|
for response in adapter.convert_message(sdk_msg):
|
||||||
|
if isinstance(response, StreamStart):
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield response
|
||||||
|
|
||||||
|
if isinstance(response, StreamTextDelta):
|
||||||
|
delta = response.delta or ""
|
||||||
|
# After tool results, start a new assistant
|
||||||
|
# message for the post-tool text.
|
||||||
|
if has_tool_results and has_appended_assistant:
|
||||||
|
assistant_response = ChatMessage(
|
||||||
|
role="assistant", content=delta
|
||||||
|
)
|
||||||
|
accumulated_tool_calls = []
|
||||||
|
has_appended_assistant = False
|
||||||
|
has_tool_results = False
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
has_appended_assistant = True
|
||||||
|
else:
|
||||||
|
assistant_response.content = (
|
||||||
|
assistant_response.content or ""
|
||||||
|
) + delta
|
||||||
|
if not has_appended_assistant:
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
has_appended_assistant = True
|
||||||
|
|
||||||
|
elif isinstance(response, StreamToolInputAvailable):
|
||||||
|
accumulated_tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": response.toolCallId,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": response.toolName,
|
||||||
|
"arguments": json.dumps(response.input or {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assistant_response.tool_calls = accumulated_tool_calls
|
||||||
|
if not has_appended_assistant:
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
has_appended_assistant = True
|
||||||
|
|
||||||
|
elif isinstance(response, StreamToolOutputAvailable):
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role="tool",
|
||||||
|
content=(
|
||||||
|
response.output
|
||||||
|
if isinstance(response.output, str)
|
||||||
|
else str(response.output)
|
||||||
|
),
|
||||||
|
tool_call_id=response.toolCallId,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
has_tool_results = True
|
||||||
|
|
||||||
|
elif isinstance(response, StreamFinish):
|
||||||
|
stream_completed = True
|
||||||
|
|
||||||
|
if stream_completed:
|
||||||
|
break
|
||||||
|
|
||||||
|
if (
|
||||||
|
assistant_response.content or assistant_response.tool_calls
|
||||||
|
) and not has_appended_assistant:
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
|
||||||
|
# --- Capture transcript while CLI is still alive ---
|
||||||
|
# Must happen INSIDE async with: close() sends SIGTERM
|
||||||
|
# which kills the CLI before it can flush the JSONL.
|
||||||
|
if (
|
||||||
|
config.claude_agent_use_resume
|
||||||
|
and user_id
|
||||||
|
and captured_transcript.available
|
||||||
|
):
|
||||||
|
# Give CLI time to flush JSONL writes before we read
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
raw_transcript = read_transcript_file(captured_transcript.path)
|
||||||
|
if raw_transcript:
|
||||||
|
task = asyncio.create_task(
|
||||||
|
_upload_transcript_bg(user_id, session_id, raw_transcript)
|
||||||
|
)
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
else:
|
||||||
|
logger.debug("[SDK] Stop hook fired but transcript not usable")
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"claude-agent-sdk is not installed. "
|
||||||
|
"Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
|
||||||
|
"to use the OpenAI-compatible fallback."
|
||||||
|
)
|
||||||
|
|
||||||
|
await upsert_chat_session(session)
|
||||||
|
logger.debug(
|
||||||
|
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
||||||
|
)
|
||||||
|
if not stream_completed:
|
||||||
|
yield StreamFinish()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
||||||
|
try:
|
||||||
|
await upsert_chat_session(session)
|
||||||
|
except Exception as save_err:
|
||||||
|
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||||
|
yield StreamError(
|
||||||
|
errorText="An error occurred. Please try again.",
|
||||||
|
code="sdk_error",
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
|
finally:
|
||||||
|
if sdk_cwd:
|
||||||
|
_cleanup_sdk_tool_results(sdk_cwd)
|
||||||
|
|
||||||
|
|
||||||
|
async def _upload_transcript_bg(
|
||||||
|
user_id: str, session_id: str, raw_content: str
|
||||||
|
) -> None:
|
||||||
|
"""Background task to strip progress entries and upload transcript."""
|
||||||
|
try:
|
||||||
|
await upload_transcript(user_id, session_id, raw_content)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_title_async(
|
||||||
|
session_id: str, message: str, user_id: str | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Background task to update session title."""
|
||||||
|
try:
|
||||||
|
title = await _generate_session_title(
|
||||||
|
message, user_id=user_id, session_id=session_id
|
||||||
|
)
|
||||||
|
if title:
|
||||||
|
await update_session_title(session_id, title)
|
||||||
|
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||||
@@ -0,0 +1,325 @@
|
|||||||
|
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
|
||||||
|
|
||||||
|
This module provides the adapter layer that converts existing BaseTool implementations
|
||||||
|
into in-process MCP tools that can be used with the Claude Agent SDK.
|
||||||
|
|
||||||
|
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
|
||||||
|
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
|
||||||
|
via a callback provided by the service layer. This avoids wasteful SDK polling
|
||||||
|
and makes results survive page refreshes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.chat.tools import TOOL_REGISTRY
|
||||||
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
||||||
|
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
||||||
|
# in the path — prevents reading settings, credentials, or other sensitive files.
|
||||||
|
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
|
||||||
|
|
||||||
|
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||||
|
MCP_SERVER_NAME = "copilot"
|
||||||
|
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
||||||
|
|
||||||
|
# Context variables to pass user/session info to tool execution
|
||||||
|
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||||
|
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||||
|
"current_session", default=None
|
||||||
|
)
|
||||||
|
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||||
|
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||||
|
# response adapter when it builds StreamToolOutputAvailable.
|
||||||
|
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
|
||||||
|
"pending_tool_outputs", default=None # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Callback type for delegating long-running tools to the non-SDK infrastructure.
|
||||||
|
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
|
||||||
|
LongRunningCallback = Callable[
|
||||||
|
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
|
||||||
|
]
|
||||||
|
|
||||||
|
# ContextVar so the service layer can inject the callback per-request.
|
||||||
|
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
|
||||||
|
"long_running_callback", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_execution_context(
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
long_running_callback: LongRunningCallback | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set the execution context for tool calls.
|
||||||
|
|
||||||
|
This must be called before streaming begins to ensure tools have access
|
||||||
|
to user_id and session information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Current user's ID.
|
||||||
|
session: Current chat session.
|
||||||
|
long_running_callback: Optional callback to delegate long-running tools
|
||||||
|
to the non-SDK background infrastructure (stream_registry + Redis).
|
||||||
|
"""
|
||||||
|
_current_user_id.set(user_id)
|
||||||
|
_current_session.set(session)
|
||||||
|
_pending_tool_outputs.set({})
|
||||||
|
_long_running_callback.set(long_running_callback)
|
||||||
|
|
||||||
|
|
||||||
|
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||||
|
"""Get the current execution context."""
|
||||||
|
return (
|
||||||
|
_current_user_id.get(),
|
||||||
|
_current_session.get(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||||
|
"""Pop and return the stashed full output for *tool_name*.
|
||||||
|
|
||||||
|
The SDK CLI may truncate large tool results (writing them to disk and
|
||||||
|
replacing the content with a file reference). This stash keeps the
|
||||||
|
original MCP output so the response adapter can forward it to the
|
||||||
|
frontend for proper widget rendering.
|
||||||
|
|
||||||
|
Returns ``None`` if nothing was stashed for *tool_name*.
|
||||||
|
"""
|
||||||
|
pending = _pending_tool_outputs.get(None)
|
||||||
|
if pending is None:
|
||||||
|
return None
|
||||||
|
return pending.pop(tool_name, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_tool_sync(
|
||||||
|
base_tool: BaseTool,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
args: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Execute a tool synchronously and return MCP-formatted response."""
|
||||||
|
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||||
|
result = await base_tool.execute(
|
||||||
|
user_id=user_id,
|
||||||
|
session=session,
|
||||||
|
tool_call_id=effective_id,
|
||||||
|
**args,
|
||||||
|
)
|
||||||
|
|
||||||
|
text = (
|
||||||
|
result.output if isinstance(result.output, str) else json.dumps(result.output)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stash the full output before the SDK potentially truncates it.
|
||||||
|
pending = _pending_tool_outputs.get(None)
|
||||||
|
if pending is not None:
|
||||||
|
pending[base_tool.name] = text
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": text}],
|
||||||
|
"isError": not result.success,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _mcp_error(message: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": json.dumps({"error": message, "type": "error"})}
|
||||||
|
],
|
||||||
|
"isError": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_tool_handler(base_tool: BaseTool):
|
||||||
|
"""Create an async handler function for a BaseTool.
|
||||||
|
|
||||||
|
This wraps the existing BaseTool._execute method to be compatible
|
||||||
|
with the Claude Agent SDK MCP tool format.
|
||||||
|
|
||||||
|
Long-running tools (``is_long_running=True``) are delegated to the
|
||||||
|
non-SDK background infrastructure via a callback set in the execution
|
||||||
|
context. The callback persists the operation in Redis (stream_registry)
|
||||||
|
so results survive page refreshes and pod restarts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Execute the wrapped tool and return MCP-formatted response."""
|
||||||
|
user_id, session = get_execution_context()
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
return _mcp_error("No session context available")
|
||||||
|
|
||||||
|
# --- Long-running: delegate to non-SDK background infrastructure ---
|
||||||
|
if base_tool.is_long_running:
|
||||||
|
callback = _long_running_callback.get(None)
|
||||||
|
if callback:
|
||||||
|
try:
|
||||||
|
return await callback(base_tool.name, args, session)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Long-running callback failed for {base_tool.name}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
|
||||||
|
# No callback — fall through to synchronous execution
|
||||||
|
logger.warning(
|
||||||
|
f"[SDK] No long-running callback for {base_tool.name}, "
|
||||||
|
f"executing synchronously (may block)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Normal (fast) tool: execute synchronously ---
|
||||||
|
try:
|
||||||
|
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||||
|
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||||
|
|
||||||
|
return tool_handler
|
||||||
|
|
||||||
|
|
||||||
|
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||||
|
"""Build a JSON Schema input schema for a tool."""
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": base_tool.parameters.get("properties", {}),
|
||||||
|
"required": base_tool.parameters.get("required", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Read a file with optional offset/limit. Restricted to SDK working directory.
|
||||||
|
|
||||||
|
After reading, the file is deleted to prevent accumulation in long-running pods.
|
||||||
|
"""
|
||||||
|
file_path = args.get("file_path", "")
|
||||||
|
offset = args.get("offset", 0)
|
||||||
|
limit = args.get("limit", 2000)
|
||||||
|
|
||||||
|
# Security: only allow reads under ~/.claude/projects/**/tool-results/
|
||||||
|
real_path = os.path.realpath(file_path)
|
||||||
|
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
||||||
|
"isError": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(real_path) as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
selected = lines[offset : offset + limit]
|
||||||
|
content = "".join(selected)
|
||||||
|
# Clean up to prevent accumulation in long-running pods
|
||||||
|
try:
|
||||||
|
os.remove(real_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return {"content": [{"type": "text", "text": content}], "isError": False}
|
||||||
|
except FileNotFoundError:
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
||||||
|
"isError": True,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
|
||||||
|
"isError": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_READ_TOOL_NAME = "Read"
|
||||||
|
_READ_TOOL_DESCRIPTION = (
|
||||||
|
"Read a file from the local filesystem. "
|
||||||
|
"Use offset and limit to read specific line ranges for large files."
|
||||||
|
)
|
||||||
|
_READ_TOOL_SCHEMA = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"file_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The absolute path to the file to read",
|
||||||
|
},
|
||||||
|
"offset": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Line number to start reading from (0-indexed). Default: 0",
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Number of lines to read. Default: 2000",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["file_path"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Create the MCP server configuration
|
||||||
|
def create_copilot_mcp_server():
|
||||||
|
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||||
|
|
||||||
|
This can be passed to ClaudeAgentOptions.mcp_servers.
|
||||||
|
|
||||||
|
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
||||||
|
package being available. This function returns the configuration that
|
||||||
|
can be used with the SDK.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||||
|
|
||||||
|
# Create decorated tool functions
|
||||||
|
sdk_tools = []
|
||||||
|
|
||||||
|
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||||
|
handler = create_tool_handler(base_tool)
|
||||||
|
decorated = tool(
|
||||||
|
tool_name,
|
||||||
|
base_tool.description,
|
||||||
|
_build_input_schema(base_tool),
|
||||||
|
)(handler)
|
||||||
|
sdk_tools.append(decorated)
|
||||||
|
|
||||||
|
# Add the Read tool so the SDK can read back oversized tool results
|
||||||
|
read_tool = tool(
|
||||||
|
_READ_TOOL_NAME,
|
||||||
|
_READ_TOOL_DESCRIPTION,
|
||||||
|
_READ_TOOL_SCHEMA,
|
||||||
|
)(_read_file_handler)
|
||||||
|
sdk_tools.append(read_tool)
|
||||||
|
|
||||||
|
server = create_sdk_mcp_server(
|
||||||
|
name=MCP_SERVER_NAME,
|
||||||
|
version="1.0.0",
|
||||||
|
tools=sdk_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
return server
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# Let ImportError propagate so service.py handles the fallback
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# SDK built-in tools allowed within the workspace directory.
|
||||||
|
# Security hooks validate that file paths stay within sdk_cwd.
|
||||||
|
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
|
||||||
|
# which provides kernel-level network isolation via unshare --net.
|
||||||
|
# Task allows spawning sub-agents (rate-limited by security hooks).
|
||||||
|
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task"]
|
||||||
|
|
||||||
|
# List of tool names for allowed_tools configuration
|
||||||
|
# Include MCP tools, the MCP Read tool for oversized results,
|
||||||
|
# and SDK built-in file tools for workspace operations.
|
||||||
|
COPILOT_TOOL_NAMES = [
|
||||||
|
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||||
|
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||||
|
*_SDK_BUILTIN_TOOLS,
|
||||||
|
]
|
||||||
@@ -0,0 +1,355 @@
|
|||||||
|
"""JSONL transcript management for stateless multi-turn resume.
|
||||||
|
|
||||||
|
The Claude Code CLI persists conversations as JSONL files (one JSON object per
|
||||||
|
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
|
||||||
|
(progress entries, metadata), and upload the result to bucket storage. On the
|
||||||
|
next turn we download the transcript, write it to a temp file, and pass
|
||||||
|
``--resume`` so the CLI can reconstruct the full conversation.
|
||||||
|
|
||||||
|
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||||
|
filesystem for self-hosted) — no DB column needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||||
|
_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]")
|
||||||
|
|
||||||
|
# Entry types that can be safely removed from the transcript without breaking
|
||||||
|
# the parentUuid conversation tree that ``--resume`` relies on.
|
||||||
|
# - progress: UI progress ticks, no message content (avg 97KB for agent_progress)
|
||||||
|
# - file-history-snapshot: undo tracking metadata
|
||||||
|
# - queue-operation: internal queue bookkeeping
|
||||||
|
# - summary: session summaries
|
||||||
|
# - pr-link: PR link metadata
|
||||||
|
STRIPPABLE_TYPES = frozenset(
|
||||||
|
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Workspace storage constants — deterministic path from session_id.
|
||||||
|
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Progress stripping
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def strip_progress_entries(content: str) -> str:
|
||||||
|
"""Remove progress/metadata entries from a JSONL transcript.
|
||||||
|
|
||||||
|
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
|
||||||
|
any remaining child entries so the ``parentUuid`` chain stays intact.
|
||||||
|
Typically reduces transcript size by ~30%.
|
||||||
|
"""
|
||||||
|
lines = content.strip().split("\n")
|
||||||
|
|
||||||
|
entries: list[dict] = []
|
||||||
|
for line in lines:
|
||||||
|
try:
|
||||||
|
entries.append(json.loads(line))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Keep unparseable lines as-is (safety)
|
||||||
|
entries.append({"_raw": line})
|
||||||
|
|
||||||
|
stripped_uuids: set[str] = set()
|
||||||
|
uuid_to_parent: dict[str, str] = {}
|
||||||
|
kept: list[dict] = []
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
if "_raw" in entry:
|
||||||
|
kept.append(entry)
|
||||||
|
continue
|
||||||
|
uid = entry.get("uuid", "")
|
||||||
|
parent = entry.get("parentUuid", "")
|
||||||
|
entry_type = entry.get("type", "")
|
||||||
|
|
||||||
|
if uid:
|
||||||
|
uuid_to_parent[uid] = parent
|
||||||
|
|
||||||
|
if entry_type in STRIPPABLE_TYPES:
|
||||||
|
if uid:
|
||||||
|
stripped_uuids.add(uid)
|
||||||
|
else:
|
||||||
|
kept.append(entry)
|
||||||
|
|
||||||
|
# Reparent: walk up chain through stripped entries to find surviving ancestor
|
||||||
|
for entry in kept:
|
||||||
|
if "_raw" in entry:
|
||||||
|
continue
|
||||||
|
parent = entry.get("parentUuid", "")
|
||||||
|
original_parent = parent
|
||||||
|
while parent in stripped_uuids:
|
||||||
|
parent = uuid_to_parent.get(parent, "")
|
||||||
|
if parent != original_parent:
|
||||||
|
entry["parentUuid"] = parent
|
||||||
|
|
||||||
|
result_lines: list[str] = []
|
||||||
|
for entry in kept:
|
||||||
|
if "_raw" in entry:
|
||||||
|
result_lines.append(entry["_raw"])
|
||||||
|
else:
|
||||||
|
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||||
|
|
||||||
|
return "\n".join(result_lines) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def read_transcript_file(transcript_path: str) -> str | None:
|
||||||
|
"""Read a JSONL transcript file from disk.
|
||||||
|
|
||||||
|
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
|
||||||
|
or only contains metadata (≤2 lines with no conversation messages).
|
||||||
|
"""
|
||||||
|
if not transcript_path or not os.path.isfile(transcript_path):
|
||||||
|
logger.debug(f"[Transcript] File not found: {transcript_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(transcript_path) as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
if not content.strip():
|
||||||
|
logger.debug(f"[Transcript] Empty file: {transcript_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
lines = content.strip().split("\n")
|
||||||
|
if len(lines) < 2:
|
||||||
|
# Metadata-only files have 1 line (single queue-operation or snapshot).
|
||||||
|
logger.debug(
|
||||||
|
f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Quick structural validation — parse first and last lines.
|
||||||
|
json.loads(lines[0])
|
||||||
|
json.loads(lines[-1])
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[Transcript] Read {len(lines)} lines, "
|
||||||
|
f"{len(content)} bytes from {transcript_path}"
|
||||||
|
)
|
||||||
|
return content
|
||||||
|
|
||||||
|
except (json.JSONDecodeError, OSError) as e:
|
||||||
|
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||||
|
"""Sanitize an ID for safe use in file paths.
|
||||||
|
|
||||||
|
Session/user IDs are expected to be UUIDs (hex + hyphens). Strip
|
||||||
|
everything else and truncate to *max_len* so the result cannot introduce
|
||||||
|
path separators or other special characters.
|
||||||
|
"""
|
||||||
|
cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len]
|
||||||
|
return cleaned or "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||||
|
|
||||||
|
|
||||||
|
def write_transcript_to_tempfile(
|
||||||
|
transcript_content: str,
|
||||||
|
session_id: str,
|
||||||
|
cwd: str,
|
||||||
|
) -> str | None:
|
||||||
|
"""Write JSONL transcript to a temp file inside *cwd* for ``--resume``.
|
||||||
|
|
||||||
|
The file lives in the session working directory so it is cleaned up
|
||||||
|
automatically when the session ends.
|
||||||
|
|
||||||
|
Returns the absolute path to the file, or ``None`` on failure.
|
||||||
|
"""
|
||||||
|
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||||
|
real_cwd = os.path.realpath(cwd)
|
||||||
|
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||||
|
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.makedirs(real_cwd, exist_ok=True)
|
||||||
|
safe_id = _sanitize_id(session_id, max_len=8)
|
||||||
|
jsonl_path = os.path.realpath(
|
||||||
|
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||||
|
)
|
||||||
|
if not jsonl_path.startswith(real_cwd):
|
||||||
|
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
with open(jsonl_path, "w") as f:
|
||||||
|
f.write(transcript_content)
|
||||||
|
|
||||||
|
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||||
|
return jsonl_path
|
||||||
|
|
||||||
|
except OSError as e:
|
||||||
|
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_transcript(content: str | None) -> bool:
|
||||||
|
"""Check that a transcript has actual conversation messages.
|
||||||
|
|
||||||
|
A valid transcript for resume needs at least one user message and one
|
||||||
|
assistant message (not just queue-operation / file-history-snapshot
|
||||||
|
metadata).
|
||||||
|
"""
|
||||||
|
if not content or not content.strip():
|
||||||
|
return False
|
||||||
|
|
||||||
|
lines = content.strip().split("\n")
|
||||||
|
if len(lines) < 2:
|
||||||
|
return False
|
||||||
|
|
||||||
|
has_user = False
|
||||||
|
has_assistant = False
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
try:
|
||||||
|
entry = json.loads(line)
|
||||||
|
msg_type = entry.get("type")
|
||||||
|
if msg_type == "user":
|
||||||
|
has_user = True
|
||||||
|
elif msg_type == "assistant":
|
||||||
|
has_assistant = True
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return has_user and has_assistant
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bucket storage (GCS / local via WorkspaceStorageBackend)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||||
|
"""Return (workspace_id, file_id, filename) for a session's transcript.
|
||||||
|
|
||||||
|
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
|
||||||
|
IDs are sanitized to hex+hyphen to prevent path traversal.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
TRANSCRIPT_STORAGE_PREFIX,
|
||||||
|
_sanitize_id(user_id),
|
||||||
|
f"{_sanitize_id(session_id)}.jsonl",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||||
|
"""Build the full storage path string that ``retrieve()`` expects.
|
||||||
|
|
||||||
|
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
|
||||||
|
``local://workspace_id/file_id/filename``. Since we use deterministic
|
||||||
|
arguments we can reconstruct the same path for download/delete without
|
||||||
|
having stored the return value.
|
||||||
|
"""
|
||||||
|
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||||
|
|
||||||
|
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||||
|
|
||||||
|
if isinstance(backend, GCSWorkspaceStorage):
|
||||||
|
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||||
|
return f"gcs://{backend.bucket_name}/{blob}"
|
||||||
|
else:
|
||||||
|
# LocalWorkspaceStorage returns local://{relative_path}
|
||||||
|
return f"local://{wid}/{fid}/{fname}"
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
|
||||||
|
"""Strip progress entries and upload transcript to bucket storage.
|
||||||
|
|
||||||
|
Safety: only overwrites when the new (stripped) transcript is larger than
|
||||||
|
what is already stored. Since JSONL is append-only, the latest transcript
|
||||||
|
is always the longest. This prevents a slow/stale background task from
|
||||||
|
clobbering a newer upload from a concurrent turn.
|
||||||
|
"""
|
||||||
|
from backend.util.workspace_storage import get_workspace_storage
|
||||||
|
|
||||||
|
stripped = strip_progress_entries(content)
|
||||||
|
if not validate_transcript(stripped):
|
||||||
|
logger.warning(
|
||||||
|
f"[Transcript] Skipping upload — stripped content is not a valid "
|
||||||
|
f"transcript for session {session_id}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||||
|
encoded = stripped.encode("utf-8")
|
||||||
|
new_size = len(encoded)
|
||||||
|
|
||||||
|
# Check existing transcript size to avoid overwriting newer with older
|
||||||
|
path = _build_storage_path(user_id, session_id, storage)
|
||||||
|
try:
|
||||||
|
existing = await storage.retrieve(path)
|
||||||
|
if len(existing) >= new_size:
|
||||||
|
logger.info(
|
||||||
|
f"[Transcript] Skipping upload — existing transcript "
|
||||||
|
f"({len(existing)}B) >= new ({new_size}B) for session "
|
||||||
|
f"{session_id}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except (FileNotFoundError, Exception):
|
||||||
|
pass # No existing transcript or retrieval error — proceed with upload
|
||||||
|
|
||||||
|
await storage.store(
|
||||||
|
workspace_id=wid,
|
||||||
|
file_id=fid,
|
||||||
|
filename=fname,
|
||||||
|
content=encoded,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[Transcript] Uploaded {new_size} bytes "
|
||||||
|
f"(stripped from {len(content)}) for session {session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def download_transcript(user_id: str, session_id: str) -> str | None:
|
||||||
|
"""Download transcript from bucket storage.
|
||||||
|
|
||||||
|
Returns the JSONL content string, or ``None`` if not found.
|
||||||
|
"""
|
||||||
|
from backend.util.workspace_storage import get_workspace_storage
|
||||||
|
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
path = _build_storage_path(user_id, session_id, storage)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await storage.retrieve(path)
|
||||||
|
content = data.decode("utf-8")
|
||||||
|
logger.info(
|
||||||
|
f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
|
||||||
|
)
|
||||||
|
return content
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[Transcript] Failed to download transcript: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||||
|
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
||||||
|
from backend.util.workspace_storage import get_workspace_storage
|
||||||
|
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
path = _build_storage_path(user_id, session_id, storage)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await storage.delete(path)
|
||||||
|
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
|
||||||
@@ -52,8 +52,10 @@ from .response_model import (
|
|||||||
StreamBaseResponse,
|
StreamBaseResponse,
|
||||||
StreamError,
|
StreamError,
|
||||||
StreamFinish,
|
StreamFinish,
|
||||||
|
StreamFinishStep,
|
||||||
StreamHeartbeat,
|
StreamHeartbeat,
|
||||||
StreamStart,
|
StreamStart,
|
||||||
|
StreamStartStep,
|
||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamTextEnd,
|
StreamTextEnd,
|
||||||
StreamTextStart,
|
StreamTextStart,
|
||||||
@@ -243,12 +245,16 @@ async def _get_system_prompt_template(context: str) -> str:
|
|||||||
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
||||||
|
|
||||||
|
|
||||||
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
async def _build_system_prompt(
|
||||||
|
user_id: str | None, has_conversation_history: bool = False
|
||||||
|
) -> tuple[str, Any]:
|
||||||
"""Build the full system prompt including business understanding if available.
|
"""Build the full system prompt including business understanding if available.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: The user ID for fetching business understanding
|
user_id: The user ID for fetching business understanding.
|
||||||
If "default" and this is the user's first session, will use "onboarding" instead.
|
has_conversation_history: Whether there's existing conversation history.
|
||||||
|
If True, we don't tell the model to greet/introduce (since they're
|
||||||
|
already in a conversation).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (compiled prompt string, business understanding object)
|
Tuple of (compiled prompt string, business understanding object)
|
||||||
@@ -264,6 +270,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
|||||||
|
|
||||||
if understanding:
|
if understanding:
|
||||||
context = format_understanding_for_prompt(understanding)
|
context = format_understanding_for_prompt(understanding)
|
||||||
|
elif has_conversation_history:
|
||||||
|
context = "No prior understanding saved yet. Continue the existing conversation naturally."
|
||||||
else:
|
else:
|
||||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||||
|
|
||||||
@@ -351,6 +359,10 @@ async def stream_chat_completion(
|
|||||||
retry_count: int = 0,
|
retry_count: int = 0,
|
||||||
session: ChatSession | None = None,
|
session: ChatSession | None = None,
|
||||||
context: dict[str, str] | None = None, # {url: str, content: str}
|
context: dict[str, str] | None = None, # {url: str, content: str}
|
||||||
|
_continuation_message_id: (
|
||||||
|
str | None
|
||||||
|
) = None, # Internal: reuse message ID for tool call continuations
|
||||||
|
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
|
||||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
"""Main entry point for streaming chat completions with database handling.
|
"""Main entry point for streaming chat completions with database handling.
|
||||||
|
|
||||||
@@ -368,24 +380,47 @@ async def stream_chat_completion(
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotFoundError: If session_id is invalid
|
NotFoundError: If session_id is invalid
|
||||||
ValueError: If max_context_messages is exceeded
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
completion_start = time.monotonic()
|
||||||
|
|
||||||
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {"component": "ChatService", "session_id": session_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
||||||
|
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"message_len": len(message) if message else 0,
|
||||||
|
"is_user_message": is_user_message,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only fetch from Redis if session not provided (initial call)
|
# Only fetch from Redis if session not provided (initial call)
|
||||||
if session is None:
|
if session is None:
|
||||||
|
fetch_start = time.monotonic()
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
fetch_time = (time.monotonic() - fetch_start) * 1000
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
|
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
||||||
f"message_count={len(session.messages) if session else 0}"
|
f"n_messages={len(session.messages) if session else 0}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": fetch_time,
|
||||||
|
"n_messages": len(session.messages) if session else 0,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Using provided session object: {session.session_id}, "
|
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
||||||
f"message_count={len(session.messages)}"
|
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||||
)
|
)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
@@ -406,23 +441,32 @@ async def stream_chat_completion(
|
|||||||
|
|
||||||
# Track user message in PostHog
|
# Track user message in PostHog
|
||||||
if is_user_message:
|
if is_user_message:
|
||||||
|
posthog_start = time.monotonic()
|
||||||
track_user_message(
|
track_user_message(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
message_length=len(message),
|
message_length=len(message),
|
||||||
)
|
)
|
||||||
|
posthog_time = (time.monotonic() - posthog_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
upsert_start = time.monotonic()
|
||||||
f"Upserting session: {session.session_id} with user id {session.user_id}, "
|
|
||||||
f"message_count={len(session.messages)}"
|
|
||||||
)
|
|
||||||
session = await upsert_chat_session(session)
|
session = await upsert_chat_session(session)
|
||||||
|
upsert_time = (time.monotonic() - upsert_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
|
||||||
|
)
|
||||||
assert session, "Session not found"
|
assert session, "Session not found"
|
||||||
|
|
||||||
# Generate title for new sessions on first user message (non-blocking)
|
# Generate title for new sessions on first user message (non-blocking)
|
||||||
# Check: is_user_message, no title yet, and this is the first user message
|
# Check: is_user_message, no title yet, and this is the first user message
|
||||||
if is_user_message and message and not session.title:
|
user_messages = [m for m in session.messages if m.role == "user"]
|
||||||
user_messages = [m for m in session.messages if m.role == "user"]
|
first_user_msg = message or (user_messages[0].content if user_messages else None)
|
||||||
|
if is_user_message and first_user_msg and not session.title:
|
||||||
if len(user_messages) == 1:
|
if len(user_messages) == 1:
|
||||||
# First user message - generate title in background
|
# First user message - generate title in background
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -430,7 +474,7 @@ async def stream_chat_completion(
|
|||||||
# Capture only the values we need (not the session object) to avoid
|
# Capture only the values we need (not the session object) to avoid
|
||||||
# stale data issues when the main flow modifies the session
|
# stale data issues when the main flow modifies the session
|
||||||
captured_session_id = session_id
|
captured_session_id = session_id
|
||||||
captured_message = message
|
captured_message = first_user_msg
|
||||||
captured_user_id = user_id
|
captured_user_id = user_id
|
||||||
|
|
||||||
async def _update_title():
|
async def _update_title():
|
||||||
@@ -454,7 +498,13 @@ async def stream_chat_completion(
|
|||||||
asyncio.create_task(_update_title())
|
asyncio.create_task(_update_title())
|
||||||
|
|
||||||
# Build system prompt with business understanding
|
# Build system prompt with business understanding
|
||||||
|
prompt_start = time.monotonic()
|
||||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||||
|
prompt_time = (time.monotonic() - prompt_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize variables for streaming
|
# Initialize variables for streaming
|
||||||
assistant_response = ChatMessage(
|
assistant_response = ChatMessage(
|
||||||
@@ -479,13 +529,27 @@ async def stream_chat_completion(
|
|||||||
# Generate unique IDs for AI SDK protocol
|
# Generate unique IDs for AI SDK protocol
|
||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
|
|
||||||
message_id = str(uuid_module.uuid4())
|
is_continuation = _continuation_message_id is not None
|
||||||
|
message_id = _continuation_message_id or str(uuid_module.uuid4())
|
||||||
text_block_id = str(uuid_module.uuid4())
|
text_block_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
# Yield message start
|
# Only yield message start for the initial call, not for continuations.
|
||||||
yield StreamStart(messageId=message_id)
|
setup_time = (time.monotonic() - completion_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||||
|
)
|
||||||
|
if not is_continuation:
|
||||||
|
yield StreamStart(messageId=message_id, taskId=_task_id)
|
||||||
|
|
||||||
|
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
|
||||||
|
yield StreamStartStep()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(
|
||||||
|
"[TIMING] Calling _stream_chat_chunks",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
async for chunk in _stream_chat_chunks(
|
async for chunk in _stream_chat_chunks(
|
||||||
session=session,
|
session=session,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -585,6 +649,10 @@ async def stream_chat_completion(
|
|||||||
)
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamFinish):
|
elif isinstance(chunk, StreamFinish):
|
||||||
|
if has_done_tool_call:
|
||||||
|
# Tool calls happened — close the step but don't send message-level finish.
|
||||||
|
# The continuation will open a new step, and finish will come at the end.
|
||||||
|
yield StreamFinishStep()
|
||||||
if not has_done_tool_call:
|
if not has_done_tool_call:
|
||||||
# Emit text-end before finish if we received text but haven't closed it
|
# Emit text-end before finish if we received text but haven't closed it
|
||||||
if has_received_text and not text_streaming_ended:
|
if has_received_text and not text_streaming_ended:
|
||||||
@@ -616,6 +684,8 @@ async def stream_chat_completion(
|
|||||||
has_saved_assistant_message = True
|
has_saved_assistant_message = True
|
||||||
|
|
||||||
has_yielded_end = True
|
has_yielded_end = True
|
||||||
|
# Emit finish-step before finish (resets AI SDK text/reasoning state)
|
||||||
|
yield StreamFinishStep()
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamError):
|
elif isinstance(chunk, StreamError):
|
||||||
has_yielded_error = True
|
has_yielded_error = True
|
||||||
@@ -665,6 +735,10 @@ async def stream_chat_completion(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
||||||
)
|
)
|
||||||
|
# Close the current step before retrying so the recursive call's
|
||||||
|
# StreamStartStep doesn't produce unbalanced step events.
|
||||||
|
if not has_yielded_end:
|
||||||
|
yield StreamFinishStep()
|
||||||
should_retry = True
|
should_retry = True
|
||||||
else:
|
else:
|
||||||
# Non-retryable error or max retries exceeded
|
# Non-retryable error or max retries exceeded
|
||||||
@@ -700,6 +774,7 @@ async def stream_chat_completion(
|
|||||||
error_response = StreamError(errorText=error_message)
|
error_response = StreamError(errorText=error_message)
|
||||||
yield error_response
|
yield error_response
|
||||||
if not has_yielded_end:
|
if not has_yielded_end:
|
||||||
|
yield StreamFinishStep()
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -714,6 +789,8 @@ async def stream_chat_completion(
|
|||||||
retry_count=retry_count + 1,
|
retry_count=retry_count + 1,
|
||||||
session=session,
|
session=session,
|
||||||
context=context,
|
context=context,
|
||||||
|
_continuation_message_id=message_id, # Reuse message ID since start was already sent
|
||||||
|
_task_id=_task_id,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
return # Exit after retry to avoid double-saving in finally block
|
return # Exit after retry to avoid double-saving in finally block
|
||||||
@@ -729,9 +806,13 @@ async def stream_chat_completion(
|
|||||||
# Build the messages list in the correct order
|
# Build the messages list in the correct order
|
||||||
messages_to_save: list[ChatMessage] = []
|
messages_to_save: list[ChatMessage] = []
|
||||||
|
|
||||||
# Add assistant message with tool_calls if any
|
# Add assistant message with tool_calls if any.
|
||||||
|
# Use extend (not assign) to preserve tool_calls already added by
|
||||||
|
# _yield_tool_call for long-running tools.
|
||||||
if accumulated_tool_calls:
|
if accumulated_tool_calls:
|
||||||
assistant_response.tool_calls = accumulated_tool_calls
|
if not assistant_response.tool_calls:
|
||||||
|
assistant_response.tool_calls = []
|
||||||
|
assistant_response.tool_calls.extend(accumulated_tool_calls)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||||
)
|
)
|
||||||
@@ -783,6 +864,8 @@ async def stream_chat_completion(
|
|||||||
session=session, # Pass session object to avoid Redis refetch
|
session=session, # Pass session object to avoid Redis refetch
|
||||||
context=context,
|
context=context,
|
||||||
tool_call_response=str(tool_response_messages),
|
tool_call_response=str(tool_response_messages),
|
||||||
|
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
|
||||||
|
_task_id=_task_id,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@@ -893,9 +976,21 @@ async def _stream_chat_chunks(
|
|||||||
SSE formatted JSON response objects
|
SSE formatted JSON response objects
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import time as time_module
|
||||||
|
|
||||||
|
stream_chunks_start = time_module.perf_counter()
|
||||||
model = config.model
|
model = config.model
|
||||||
|
|
||||||
logger.info("Starting pure chat stream")
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {"component": "ChatService", "session_id": session.session_id}
|
||||||
|
if session.user_id:
|
||||||
|
log_meta["user_id"] = session.user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
|
||||||
|
f"user={session.user_id}, n_messages={len(session.messages)}",
|
||||||
|
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||||
|
)
|
||||||
|
|
||||||
messages = session.to_openai_messages()
|
messages = session.to_openai_messages()
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
@@ -906,12 +1001,18 @@ async def _stream_chat_chunks(
|
|||||||
messages = [system_message] + messages
|
messages = [system_message] + messages
|
||||||
|
|
||||||
# Apply context window management
|
# Apply context window management
|
||||||
|
context_start = time_module.perf_counter()
|
||||||
context_result = await _manage_context_window(
|
context_result = await _manage_context_window(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=model,
|
model=model,
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
base_url=config.base_url,
|
base_url=config.base_url,
|
||||||
)
|
)
|
||||||
|
context_time = (time_module.perf_counter() - context_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
|
||||||
|
)
|
||||||
|
|
||||||
if context_result.error:
|
if context_result.error:
|
||||||
if "System prompt dropped" in context_result.error:
|
if "System prompt dropped" in context_result.error:
|
||||||
@@ -946,9 +1047,19 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
while retry_count <= MAX_RETRIES:
|
while retry_count <= MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
|
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||||
|
retry_info = (
|
||||||
|
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Creating OpenAI chat completion stream..."
|
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
|
||||||
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"retry_count": retry_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build extra_body for OpenRouter tracing and PostHog analytics
|
# Build extra_body for OpenRouter tracing and PostHog analytics
|
||||||
@@ -965,6 +1076,11 @@ async def _stream_chat_chunks(
|
|||||||
:128
|
:128
|
||||||
] # OpenRouter limit
|
] # OpenRouter limit
|
||||||
|
|
||||||
|
# Enable adaptive thinking for Anthropic models via OpenRouter
|
||||||
|
if config.thinking_enabled and "anthropic" in model.lower():
|
||||||
|
extra_body["reasoning"] = {"enabled": True}
|
||||||
|
|
||||||
|
api_call_start = time_module.perf_counter()
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||||
@@ -974,6 +1090,11 @@ async def _stream_chat_chunks(
|
|||||||
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
|
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# Variables to accumulate tool calls
|
# Variables to accumulate tool calls
|
||||||
tool_calls: list[dict[str, Any]] = []
|
tool_calls: list[dict[str, Any]] = []
|
||||||
@@ -984,10 +1105,13 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
# Track if we've started the text block
|
# Track if we've started the text block
|
||||||
text_started = False
|
text_started = False
|
||||||
|
first_content_chunk = True
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
# Process the stream
|
# Process the stream
|
||||||
chunk: ChatCompletionChunk
|
chunk: ChatCompletionChunk
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
|
chunk_count += 1
|
||||||
if chunk.usage:
|
if chunk.usage:
|
||||||
yield StreamUsage(
|
yield StreamUsage(
|
||||||
promptTokens=chunk.usage.prompt_tokens,
|
promptTokens=chunk.usage.prompt_tokens,
|
||||||
@@ -1010,6 +1134,23 @@ async def _stream_chat_chunks(
|
|||||||
if not text_started and text_block_id:
|
if not text_started and text_block_id:
|
||||||
yield StreamTextStart(id=text_block_id)
|
yield StreamTextStart(id=text_block_id)
|
||||||
text_started = True
|
text_started = True
|
||||||
|
# Log timing for first content chunk
|
||||||
|
if first_content_chunk:
|
||||||
|
first_content_chunk = False
|
||||||
|
ttfc = (
|
||||||
|
time_module.perf_counter() - api_call_start
|
||||||
|
) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
|
||||||
|
f"(since API call), n_chunks={chunk_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"time_to_first_chunk_ms": ttfc,
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
# Stream the text delta
|
# Stream the text delta
|
||||||
text_response = StreamTextDelta(
|
text_response = StreamTextDelta(
|
||||||
id=text_block_id or "",
|
id=text_block_id or "",
|
||||||
@@ -1066,7 +1207,21 @@ async def _stream_chat_chunks(
|
|||||||
toolName=tool_calls[idx]["function"]["name"],
|
toolName=tool_calls[idx]["function"]["name"],
|
||||||
)
|
)
|
||||||
emitted_start_for_idx.add(idx)
|
emitted_start_for_idx.add(idx)
|
||||||
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
stream_duration = time_module.perf_counter() - api_call_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
|
||||||
|
f"duration={stream_duration:.2f}s, "
|
||||||
|
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"stream_duration_ms": stream_duration * 1000,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
"n_tool_calls": len(tool_calls),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Yield all accumulated tool calls after the stream is complete
|
# Yield all accumulated tool calls after the stream is complete
|
||||||
# This ensures all tool call arguments have been fully received
|
# This ensures all tool call arguments have been fully received
|
||||||
@@ -1086,6 +1241,12 @@ async def _stream_chat_chunks(
|
|||||||
# Re-raise to trigger retry logic in the parent function
|
# Re-raise to trigger retry logic in the parent function
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
|
||||||
|
f"session={session.session_id}, user={session.user_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||||
|
)
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1253,13 +1414,9 @@ async def _yield_tool_call(
|
|||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save assistant message with tool_call FIRST (required by LLM)
|
# Attach the tool_call to the current turn's assistant message
|
||||||
assistant_message = ChatMessage(
|
# (or create one if this is a tool-only response with no text).
|
||||||
role="assistant",
|
session.add_tool_call_to_current_turn(tool_calls[yield_idx])
|
||||||
content="",
|
|
||||||
tool_calls=[tool_calls[yield_idx]],
|
|
||||||
)
|
|
||||||
session.messages.append(assistant_message)
|
|
||||||
|
|
||||||
# Then save pending tool result
|
# Then save pending tool result
|
||||||
pending_message = ChatMessage(
|
pending_message = ChatMessage(
|
||||||
@@ -1565,6 +1722,7 @@ async def _execute_long_running_tool_with_streaming(
|
|||||||
task_id,
|
task_id,
|
||||||
StreamError(errorText=str(e)),
|
StreamError(errorText=str(e)),
|
||||||
)
|
)
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
|
||||||
await _update_pending_operation(
|
await _update_pending_operation(
|
||||||
@@ -1681,6 +1839,10 @@ async def _generate_llm_continuation(
|
|||||||
if session_id:
|
if session_id:
|
||||||
extra_body["session_id"] = session_id[:128]
|
extra_body["session_id"] = session_id[:128]
|
||||||
|
|
||||||
|
# Enable adaptive thinking for Anthropic models via OpenRouter
|
||||||
|
if config.thinking_enabled and "anthropic" in config.model.lower():
|
||||||
|
extra_body["reasoning"] = {"enabled": True}
|
||||||
|
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
last_error: Exception | None = None
|
last_error: Exception | None = None
|
||||||
response = None
|
response = None
|
||||||
@@ -1811,6 +1973,10 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
if session_id:
|
if session_id:
|
||||||
extra_body["session_id"] = session_id[:128]
|
extra_body["session_id"] = session_id[:128]
|
||||||
|
|
||||||
|
# Enable adaptive thinking for Anthropic models via OpenRouter
|
||||||
|
if config.thinking_enabled and "anthropic" in config.model.lower():
|
||||||
|
extra_body["reasoning"] = {"enabled": True}
|
||||||
|
|
||||||
# Make streaming LLM call (no tools - just text response)
|
# Make streaming LLM call (no tools - just text response)
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
@@ -1822,6 +1988,7 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
|
|
||||||
# Publish start event
|
# Publish start event
|
||||||
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamStartStep())
|
||||||
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
||||||
|
|
||||||
# Stream the response
|
# Stream the response
|
||||||
@@ -1845,6 +2012,7 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
|
|
||||||
# Publish end events
|
# Publish end events
|
||||||
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
|
|
||||||
if assistant_content:
|
if assistant_content:
|
||||||
# Reload session from DB to avoid race condition with user messages
|
# Reload session from DB to avoid race condition with user messages
|
||||||
@@ -1886,4 +2054,5 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
task_id,
|
task_id,
|
||||||
StreamError(errorText=f"Failed to generate response: {e}"),
|
StreamError(errorText=f"Failed to generate response: {e}"),
|
||||||
)
|
)
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from os import getenv
|
from os import getenv
|
||||||
|
|
||||||
@@ -11,6 +12,8 @@ from .response_model import (
|
|||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamToolOutputAvailable,
|
StreamToolOutputAvailable,
|
||||||
)
|
)
|
||||||
|
from .sdk import service as sdk_service
|
||||||
|
from .sdk.transcript import download_transcript
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -80,3 +83,96 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
|
|||||||
session = await get_chat_session(session.session_id)
|
session = await get_chat_session(session.session_id)
|
||||||
assert session, "Session not found"
|
assert session, "Session not found"
|
||||||
assert session.usage, "Usage is empty"
|
assert session.usage, "Usage is empty"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||||
|
"""Test that the SDK --resume path captures and uses transcripts across turns.
|
||||||
|
|
||||||
|
Turn 1: Send a message containing a unique keyword.
|
||||||
|
Turn 2: Ask the model to recall that keyword — proving the transcript was
|
||||||
|
persisted and restored via --resume.
|
||||||
|
"""
|
||||||
|
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||||
|
|
||||||
|
from .config import ChatConfig
|
||||||
|
|
||||||
|
cfg = ChatConfig()
|
||||||
|
if not cfg.claude_agent_use_resume:
|
||||||
|
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
|
||||||
|
|
||||||
|
session = await create_chat_session(test_user_id)
|
||||||
|
session = await upsert_chat_session(session)
|
||||||
|
|
||||||
|
# --- Turn 1: send a message with a unique keyword ---
|
||||||
|
keyword = "ZEPHYR42"
|
||||||
|
turn1_msg = (
|
||||||
|
f"Please remember this special keyword: {keyword}. "
|
||||||
|
"Just confirm you've noted it, keep your response brief."
|
||||||
|
)
|
||||||
|
turn1_text = ""
|
||||||
|
turn1_errors: list[str] = []
|
||||||
|
turn1_ended = False
|
||||||
|
|
||||||
|
async for chunk in sdk_service.stream_chat_completion_sdk(
|
||||||
|
session.session_id,
|
||||||
|
turn1_msg,
|
||||||
|
user_id=test_user_id,
|
||||||
|
):
|
||||||
|
if isinstance(chunk, StreamTextDelta):
|
||||||
|
turn1_text += chunk.delta
|
||||||
|
elif isinstance(chunk, StreamError):
|
||||||
|
turn1_errors.append(chunk.errorText)
|
||||||
|
elif isinstance(chunk, StreamFinish):
|
||||||
|
turn1_ended = True
|
||||||
|
|
||||||
|
assert turn1_ended, "Turn 1 did not finish"
|
||||||
|
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
|
||||||
|
assert turn1_text, "Turn 1 produced no text"
|
||||||
|
|
||||||
|
# Wait for background upload task to complete (retry up to 5s)
|
||||||
|
transcript = None
|
||||||
|
for _ in range(10):
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
transcript = await download_transcript(test_user_id, session.session_id)
|
||||||
|
if transcript:
|
||||||
|
break
|
||||||
|
assert transcript, (
|
||||||
|
"Transcript was not uploaded to bucket after turn 1 — "
|
||||||
|
"Stop hook may not have fired or transcript was too small"
|
||||||
|
)
|
||||||
|
logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
|
||||||
|
|
||||||
|
# Reload session for turn 2
|
||||||
|
session = await get_chat_session(session.session_id, test_user_id)
|
||||||
|
assert session, "Session not found after turn 1"
|
||||||
|
|
||||||
|
# --- Turn 2: ask model to recall the keyword ---
|
||||||
|
turn2_msg = "What was the special keyword I asked you to remember?"
|
||||||
|
turn2_text = ""
|
||||||
|
turn2_errors: list[str] = []
|
||||||
|
turn2_ended = False
|
||||||
|
|
||||||
|
async for chunk in sdk_service.stream_chat_completion_sdk(
|
||||||
|
session.session_id,
|
||||||
|
turn2_msg,
|
||||||
|
user_id=test_user_id,
|
||||||
|
session=session,
|
||||||
|
):
|
||||||
|
if isinstance(chunk, StreamTextDelta):
|
||||||
|
turn2_text += chunk.delta
|
||||||
|
elif isinstance(chunk, StreamError):
|
||||||
|
turn2_errors.append(chunk.errorText)
|
||||||
|
elif isinstance(chunk, StreamFinish):
|
||||||
|
turn2_ended = True
|
||||||
|
|
||||||
|
assert turn2_ended, "Turn 2 did not finish"
|
||||||
|
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
|
||||||
|
assert turn2_text, "Turn 2 produced no text"
|
||||||
|
assert keyword in turn2_text, (
|
||||||
|
f"Model did not recall keyword '{keyword}' in turn 2. "
|
||||||
|
f"Response: {turn2_text[:200]}"
|
||||||
|
)
|
||||||
|
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")
|
||||||
|
|||||||
@@ -104,6 +104,24 @@ async def create_task(
|
|||||||
Returns:
|
Returns:
|
||||||
The created ActiveTask instance (metadata only)
|
The created ActiveTask instance (metadata only)
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {
|
||||||
|
"component": "StreamRegistry",
|
||||||
|
"task_id": task_id,
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
|
||||||
task = ActiveTask(
|
task = ActiveTask(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -114,10 +132,18 @@ async def create_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store metadata in Redis
|
# Store metadata in Redis
|
||||||
|
redis_start = time.perf_counter()
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
|
redis_time = (time.perf_counter() - redis_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
|
||||||
|
)
|
||||||
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
op_key = _get_operation_mapping_key(operation_id)
|
op_key = _get_operation_mapping_key(operation_id)
|
||||||
|
|
||||||
|
hset_start = time.perf_counter()
|
||||||
await redis.hset( # type: ignore[misc]
|
await redis.hset( # type: ignore[misc]
|
||||||
meta_key,
|
meta_key,
|
||||||
mapping={
|
mapping={
|
||||||
@@ -131,12 +157,22 @@ async def create_task(
|
|||||||
"created_at": task.created_at.isoformat(),
|
"created_at": task.created_at.isoformat(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
hset_time = (time.perf_counter() - hset_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] redis.hset took {hset_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
|
||||||
|
)
|
||||||
|
|
||||||
await redis.expire(meta_key, config.stream_ttl)
|
await redis.expire(meta_key, config.stream_ttl)
|
||||||
|
|
||||||
# Create operation_id -> task_id mapping for webhook lookups
|
# Create operation_id -> task_id mapping for webhook lookups
|
||||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||||
|
|
||||||
logger.debug(f"Created task {task_id} for session {session_id}")
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||||
|
)
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@@ -156,26 +192,60 @@ async def publish_chunk(
|
|||||||
Returns:
|
Returns:
|
||||||
The Redis Stream message ID
|
The Redis Stream message ID
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
chunk_type = type(chunk).__name__
|
||||||
chunk_json = chunk.model_dump_json()
|
chunk_json = chunk.model_dump_json()
|
||||||
message_id = "0-0"
|
message_id = "0-0"
|
||||||
|
|
||||||
|
# Build log metadata
|
||||||
|
log_meta = {
|
||||||
|
"component": "StreamRegistry",
|
||||||
|
"task_id": task_id,
|
||||||
|
"chunk_type": chunk_type,
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Write to Redis Stream for persistence and real-time delivery
|
# Write to Redis Stream for persistence and real-time delivery
|
||||||
|
xadd_start = time.perf_counter()
|
||||||
raw_id = await redis.xadd(
|
raw_id = await redis.xadd(
|
||||||
stream_key,
|
stream_key,
|
||||||
{"data": chunk_json},
|
{"data": chunk_json},
|
||||||
maxlen=config.stream_max_length,
|
maxlen=config.stream_max_length,
|
||||||
)
|
)
|
||||||
|
xadd_time = (time.perf_counter() - xadd_start) * 1000
|
||||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||||
|
|
||||||
# Set TTL on stream to match task metadata TTL
|
# Set TTL on stream to match task metadata TTL
|
||||||
await redis.expire(stream_key, config.stream_ttl)
|
await redis.expire(stream_key, config.stream_ttl)
|
||||||
|
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
# Only log timing for significant chunks or slow operations
|
||||||
|
if (
|
||||||
|
chunk_type
|
||||||
|
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
|
||||||
|
or total_time > 50
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"xadd_time_ms": xadd_time,
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to publish chunk for task {task_id}: {e}",
|
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -200,24 +270,61 @@ async def subscribe_to_task(
|
|||||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||||
or user doesn't have access
|
or user doesn't have access
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Build log metadata
|
||||||
|
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
redis_start = time.perf_counter()
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||||
|
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
||||||
|
)
|
||||||
|
|
||||||
if not meta:
|
if not meta:
|
||||||
logger.debug(f"Task {task_id} not found in Redis")
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"reason": "task_not_found",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||||
task_status = meta.get("status", "")
|
task_status = meta.get("status", "")
|
||||||
task_user_id = meta.get("user_id", "") or None
|
task_user_id = meta.get("user_id", "") or None
|
||||||
|
log_meta["session_id"] = meta.get("session_id", "")
|
||||||
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
# Validate ownership - if task has an owner, requester must match
|
||||||
if task_user_id:
|
if task_user_id:
|
||||||
if user_id != task_user_id:
|
if user_id != task_user_id:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"User {user_id} denied access to task {task_id} "
|
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
|
||||||
f"owned by {task_user_id}"
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"task_owner": task_user_id,
|
||||||
|
"reason": "access_denied",
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -225,7 +332,19 @@ async def subscribe_to_task(
|
|||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Step 1: Replay messages from Redis Stream
|
# Step 1: Replay messages from Redis Stream
|
||||||
|
xread_start = time.perf_counter()
|
||||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||||
|
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
"task_status": task_status,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
replayed_count = 0
|
replayed_count = 0
|
||||||
replay_last_id = last_message_id
|
replay_last_id = last_message_id
|
||||||
@@ -244,19 +363,48 @@ async def subscribe_to_task(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to replay message: {e}")
|
logger.warning(f"Failed to replay message: {e}")
|
||||||
|
|
||||||
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
logger.info(
|
||||||
|
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"n_messages_replayed": replayed_count,
|
||||||
|
"replay_last_id": replay_last_id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Step 2: If task is still running, start stream listener for live updates
|
# Step 2: If task is still running, start stream listener for live updates
|
||||||
if task_status == "running":
|
if task_status == "running":
|
||||||
|
logger.info(
|
||||||
|
"[TIMING] Task still running, starting _stream_listener",
|
||||||
|
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||||
|
)
|
||||||
listener_task = asyncio.create_task(
|
listener_task = asyncio.create_task(
|
||||||
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
|
||||||
)
|
)
|
||||||
# Track listener task for cleanup on unsubscribe
|
# Track listener task for cleanup on unsubscribe
|
||||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||||
else:
|
else:
|
||||||
# Task is completed/failed - add finish marker
|
# Task is completed/failed - add finish marker
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Task already {task_status}, adding StreamFinish",
|
||||||
|
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||||
|
)
|
||||||
await subscriber_queue.put(StreamFinish())
|
await subscriber_queue.put(StreamFinish())
|
||||||
|
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
|
||||||
|
f"n_messages_replayed={replayed_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"n_messages_replayed": replayed_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
return subscriber_queue
|
return subscriber_queue
|
||||||
|
|
||||||
|
|
||||||
@@ -264,6 +412,7 @@ async def _stream_listener(
|
|||||||
task_id: str,
|
task_id: str,
|
||||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||||
last_replayed_id: str,
|
last_replayed_id: str,
|
||||||
|
log_meta: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||||
|
|
||||||
@@ -274,10 +423,27 @@ async def _stream_listener(
|
|||||||
task_id: Task ID to listen for
|
task_id: Task ID to listen for
|
||||||
subscriber_queue: Queue to deliver messages to
|
subscriber_queue: Queue to deliver messages to
|
||||||
last_replayed_id: Last message ID from replay (continue from here)
|
last_replayed_id: Last message ID from replay (continue from here)
|
||||||
|
log_meta: Structured logging metadata
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Use provided log_meta or build minimal one
|
||||||
|
if log_meta is None:
|
||||||
|
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
|
||||||
|
)
|
||||||
|
|
||||||
queue_id = id(subscriber_queue)
|
queue_id = id(subscriber_queue)
|
||||||
# Track the last successfully delivered message ID for recovery hints
|
# Track the last successfully delivered message ID for recovery hints
|
||||||
last_delivered_id = last_replayed_id
|
last_delivered_id = last_replayed_id
|
||||||
|
messages_delivered = 0
|
||||||
|
first_message_time = None
|
||||||
|
xread_count = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
@@ -287,9 +453,39 @@ async def _stream_listener(
|
|||||||
while True:
|
while True:
|
||||||
# Block for up to 30 seconds waiting for new messages
|
# Block for up to 30 seconds waiting for new messages
|
||||||
# This allows periodic checking if task is still running
|
# This allows periodic checking if task is still running
|
||||||
|
xread_start = time.perf_counter()
|
||||||
|
xread_count += 1
|
||||||
messages = await redis.xread(
|
messages = await redis.xread(
|
||||||
{stream_key: current_id}, block=30000, count=100
|
{stream_key: current_id}, block=30000, count=100
|
||||||
)
|
)
|
||||||
|
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||||
|
|
||||||
|
if messages:
|
||||||
|
msg_count = sum(len(msgs) for _, msgs in messages)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
"n_messages": msg_count,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif xread_time > 1000:
|
||||||
|
# Only log timeouts (30s blocking)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
"reason": "timeout",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if not messages:
|
if not messages:
|
||||||
# Timeout - check if task is still running
|
# Timeout - check if task is still running
|
||||||
@@ -326,10 +522,30 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
# Update last delivered ID on successful delivery
|
# Update last delivered ID on successful delivery
|
||||||
last_delivered_id = current_id
|
last_delivered_id = current_id
|
||||||
|
messages_delivered += 1
|
||||||
|
if first_message_time is None:
|
||||||
|
first_message_time = time.perf_counter()
|
||||||
|
elapsed = (first_message_time - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"chunk_type": type(chunk).__name__,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Subscriber queue full for task {task_id}, "
|
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
||||||
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"timeout_s": QUEUE_PUT_TIMEOUT,
|
||||||
|
"reason": "queue_full",
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
# Send overflow error with recovery info
|
# Send overflow error with recovery info
|
||||||
try:
|
try:
|
||||||
@@ -351,15 +567,44 @@ async def _stream_listener(
|
|||||||
|
|
||||||
# Stop listening on finish
|
# Stop listening on finish
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error processing stream message: {e}")
|
logger.warning(
|
||||||
|
f"Error processing stream message: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "error": str(e)}},
|
||||||
|
)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.debug(f"Stream listener cancelled for task {task_id}")
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
"reason": "cancelled",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
raise # Re-raise to propagate cancellation
|
raise # Re-raise to propagate cancellation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Stream listener error for task {task_id}: {e}")
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.error(
|
||||||
|
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||||
|
)
|
||||||
# On error, send finish to unblock subscriber
|
# On error, send finish to unblock subscriber
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
@@ -368,10 +613,24 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Could not deliver finish event for task {task_id} after error"
|
"Could not deliver finish event after error",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Clean up listener task mapping on exit
|
# Clean up listener task mapping on exit
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
|
||||||
|
f"delivered={messages_delivered}, xread_count={xread_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
_listener_tasks.pop(queue_id, None)
|
_listener_tasks.pop(queue_id, None)
|
||||||
|
|
||||||
|
|
||||||
@@ -555,6 +814,28 @@ async def get_active_task_for_session(
|
|||||||
if task_user_id and user_id != task_user_id:
|
if task_user_id and user_id != task_user_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Auto-expire stale tasks that exceeded stream_timeout
|
||||||
|
created_at_str = meta.get("created_at", "")
|
||||||
|
if created_at_str:
|
||||||
|
try:
|
||||||
|
created_at = datetime.fromisoformat(created_at_str)
|
||||||
|
age_seconds = (
|
||||||
|
datetime.now(timezone.utc) - created_at
|
||||||
|
).total_seconds()
|
||||||
|
if age_seconds > config.stream_timeout:
|
||||||
|
logger.warning(
|
||||||
|
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
|
||||||
|
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
|
||||||
|
)
|
||||||
|
await mark_task_completed(task_id, "failed")
|
||||||
|
continue
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
|
||||||
|
)
|
||||||
|
|
||||||
# Get the last message ID from Redis Stream
|
# Get the last message ID from Redis Stream
|
||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
last_id = "0-0"
|
last_id = "0-0"
|
||||||
@@ -598,8 +879,10 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
|||||||
ResponseType,
|
ResponseType,
|
||||||
StreamError,
|
StreamError,
|
||||||
StreamFinish,
|
StreamFinish,
|
||||||
|
StreamFinishStep,
|
||||||
StreamHeartbeat,
|
StreamHeartbeat,
|
||||||
StreamStart,
|
StreamStart,
|
||||||
|
StreamStartStep,
|
||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamTextEnd,
|
StreamTextEnd,
|
||||||
StreamTextStart,
|
StreamTextStart,
|
||||||
@@ -613,6 +896,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
|||||||
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
||||||
ResponseType.START.value: StreamStart,
|
ResponseType.START.value: StreamStart,
|
||||||
ResponseType.FINISH.value: StreamFinish,
|
ResponseType.FINISH.value: StreamFinish,
|
||||||
|
ResponseType.START_STEP.value: StreamStartStep,
|
||||||
|
ResponseType.FINISH_STEP.value: StreamFinishStep,
|
||||||
ResponseType.TEXT_START.value: StreamTextStart,
|
ResponseType.TEXT_START.value: StreamTextStart,
|
||||||
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
||||||
ResponseType.TEXT_END.value: StreamTextEnd,
|
ResponseType.TEXT_END.value: StreamTextEnd,
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from backend.api.features.chat.tracking import track_tool_called
|
|||||||
from .add_understanding import AddUnderstandingTool
|
from .add_understanding import AddUnderstandingTool
|
||||||
from .agent_output import AgentOutputTool
|
from .agent_output import AgentOutputTool
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .bash_exec import BashExecTool
|
||||||
|
from .check_operation_status import CheckOperationStatusTool
|
||||||
from .create_agent import CreateAgentTool
|
from .create_agent import CreateAgentTool
|
||||||
from .customize_agent import CustomizeAgentTool
|
from .customize_agent import CustomizeAgentTool
|
||||||
from .edit_agent import EditAgentTool
|
from .edit_agent import EditAgentTool
|
||||||
@@ -19,6 +21,7 @@ from .get_doc_page import GetDocPageTool
|
|||||||
from .run_agent import RunAgentTool
|
from .run_agent import RunAgentTool
|
||||||
from .run_block import RunBlockTool
|
from .run_block import RunBlockTool
|
||||||
from .search_docs import SearchDocsTool
|
from .search_docs import SearchDocsTool
|
||||||
|
from .web_fetch import WebFetchTool
|
||||||
from .workspace_files import (
|
from .workspace_files import (
|
||||||
DeleteWorkspaceFileTool,
|
DeleteWorkspaceFileTool,
|
||||||
ListWorkspaceFilesTool,
|
ListWorkspaceFilesTool,
|
||||||
@@ -43,9 +46,14 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
|||||||
"run_agent": RunAgentTool(),
|
"run_agent": RunAgentTool(),
|
||||||
"run_block": RunBlockTool(),
|
"run_block": RunBlockTool(),
|
||||||
"view_agent_output": AgentOutputTool(),
|
"view_agent_output": AgentOutputTool(),
|
||||||
|
"check_operation_status": CheckOperationStatusTool(),
|
||||||
"search_docs": SearchDocsTool(),
|
"search_docs": SearchDocsTool(),
|
||||||
"get_doc_page": GetDocPageTool(),
|
"get_doc_page": GetDocPageTool(),
|
||||||
# Workspace tools for CoPilot file operations
|
# Web fetch for safe URL retrieval
|
||||||
|
"web_fetch": WebFetchTool(),
|
||||||
|
# Sandboxed code execution (bubblewrap)
|
||||||
|
"bash_exec": BashExecTool(),
|
||||||
|
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||||
"list_workspace_files": ListWorkspaceFilesTool(),
|
"list_workspace_files": ListWorkspaceFilesTool(),
|
||||||
"read_workspace_file": ReadWorkspaceFileTool(),
|
"read_workspace_file": ReadWorkspaceFileTool(),
|
||||||
"write_workspace_file": WriteWorkspaceFileTool(),
|
"write_workspace_file": WriteWorkspaceFileTool(),
|
||||||
|
|||||||
@@ -0,0 +1,154 @@
|
|||||||
|
"""Dummy Agent Generator for testing.
|
||||||
|
|
||||||
|
Returns mock responses matching the format expected from the external service.
|
||||||
|
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
|
||||||
|
|
||||||
|
WARNING: This is for testing only. Do not use in production.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Dummy decomposition result (instructions type)
|
||||||
|
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{
|
||||||
|
"description": "Get input from user",
|
||||||
|
"action": "input",
|
||||||
|
"block_name": "AgentInputBlock",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Process the input",
|
||||||
|
"action": "process",
|
||||||
|
"block_name": "TextFormatterBlock",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Return output to user",
|
||||||
|
"action": "output",
|
||||||
|
"block_name": "AgentOutputBlock",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Block IDs from backend/blocks/io.py
|
||||||
|
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||||
|
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_dummy_agent_json() -> dict[str, Any]:
|
||||||
|
"""Generate a minimal valid agent JSON for testing."""
|
||||||
|
input_node_id = str(uuid.uuid4())
|
||||||
|
output_node_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"version": 1,
|
||||||
|
"is_active": True,
|
||||||
|
"name": "Dummy Test Agent",
|
||||||
|
"description": "A dummy agent generated for testing purposes",
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"id": input_node_id,
|
||||||
|
"block_id": AGENT_INPUT_BLOCK_ID,
|
||||||
|
"input_default": {
|
||||||
|
"name": "input",
|
||||||
|
"title": "Input",
|
||||||
|
"description": "Enter your input",
|
||||||
|
"placeholder_values": [],
|
||||||
|
},
|
||||||
|
"metadata": {"position": {"x": 0, "y": 0}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": output_node_id,
|
||||||
|
"block_id": AGENT_OUTPUT_BLOCK_ID,
|
||||||
|
"input_default": {
|
||||||
|
"name": "output",
|
||||||
|
"title": "Output",
|
||||||
|
"description": "Agent output",
|
||||||
|
"format": "{output}",
|
||||||
|
},
|
||||||
|
"metadata": {"position": {"x": 400, "y": 0}},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"links": [
|
||||||
|
{
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"source_id": input_node_id,
|
||||||
|
"sink_id": output_node_id,
|
||||||
|
"source_name": "result",
|
||||||
|
"sink_name": "value",
|
||||||
|
"is_static": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def decompose_goal_dummy(
|
||||||
|
description: str,
|
||||||
|
context: str = "",
|
||||||
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return dummy decomposition result."""
|
||||||
|
logger.info("Using dummy agent generator for decompose_goal")
|
||||||
|
return DUMMY_DECOMPOSITION_RESULT.copy()
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_agent_dummy(
|
||||||
|
instructions: dict[str, Any],
|
||||||
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return dummy agent JSON after a simulated delay."""
|
||||||
|
logger.info("Using dummy agent generator for generate_agent (30s delay)")
|
||||||
|
await asyncio.sleep(30)
|
||||||
|
return _generate_dummy_agent_json()
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_agent_patch_dummy(
|
||||||
|
update_request: str,
|
||||||
|
current_agent: dict[str, Any],
|
||||||
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return dummy patched agent (returns the current agent with updated description)."""
|
||||||
|
logger.info("Using dummy agent generator for generate_agent_patch")
|
||||||
|
patched = current_agent.copy()
|
||||||
|
patched["description"] = (
|
||||||
|
f"{current_agent.get('description', '')} (updated: {update_request})"
|
||||||
|
)
|
||||||
|
return patched
|
||||||
|
|
||||||
|
|
||||||
|
async def customize_template_dummy(
|
||||||
|
template_agent: dict[str, Any],
|
||||||
|
modification_request: str,
|
||||||
|
context: str = "",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return dummy customized template (returns template with updated description)."""
|
||||||
|
logger.info("Using dummy agent generator for customize_template")
|
||||||
|
customized = template_agent.copy()
|
||||||
|
customized["description"] = (
|
||||||
|
f"{template_agent.get('description', '')} (customized: {modification_request})"
|
||||||
|
)
|
||||||
|
return customized
|
||||||
|
|
||||||
|
|
||||||
|
async def get_blocks_dummy() -> list[dict[str, Any]]:
|
||||||
|
"""Return dummy blocks list."""
|
||||||
|
logger.info("Using dummy agent generator for get_blocks")
|
||||||
|
return [
|
||||||
|
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
|
||||||
|
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def health_check_dummy() -> bool:
|
||||||
|
"""Always returns healthy for dummy service."""
|
||||||
|
return True
|
||||||
@@ -12,8 +12,19 @@ import httpx
|
|||||||
|
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
from .dummy import (
|
||||||
|
customize_template_dummy,
|
||||||
|
decompose_goal_dummy,
|
||||||
|
generate_agent_dummy,
|
||||||
|
generate_agent_patch_dummy,
|
||||||
|
get_blocks_dummy,
|
||||||
|
health_check_dummy,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_dummy_mode_warned = False
|
||||||
|
|
||||||
|
|
||||||
def _create_error_response(
|
def _create_error_response(
|
||||||
error_message: str,
|
error_message: str,
|
||||||
@@ -90,10 +101,26 @@ def _get_settings() -> Settings:
|
|||||||
return _settings
|
return _settings
|
||||||
|
|
||||||
|
|
||||||
def is_external_service_configured() -> bool:
|
def _is_dummy_mode() -> bool:
|
||||||
"""Check if external Agent Generator service is configured."""
|
"""Check if dummy mode is enabled for testing."""
|
||||||
|
global _dummy_mode_warned
|
||||||
settings = _get_settings()
|
settings = _get_settings()
|
||||||
return bool(settings.config.agentgenerator_host)
|
is_dummy = bool(settings.config.agentgenerator_use_dummy)
|
||||||
|
if is_dummy and not _dummy_mode_warned:
|
||||||
|
logger.warning(
|
||||||
|
"Agent Generator running in DUMMY MODE - returning mock responses. "
|
||||||
|
"Do not use in production!"
|
||||||
|
)
|
||||||
|
_dummy_mode_warned = True
|
||||||
|
return is_dummy
|
||||||
|
|
||||||
|
|
||||||
|
def is_external_service_configured() -> bool:
|
||||||
|
"""Check if external Agent Generator service is configured (or dummy mode)."""
|
||||||
|
settings = _get_settings()
|
||||||
|
return bool(settings.config.agentgenerator_host) or bool(
|
||||||
|
settings.config.agentgenerator_use_dummy
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_base_url() -> str:
|
def _get_base_url() -> str:
|
||||||
@@ -137,6 +164,9 @@ async def decompose_goal_external(
|
|||||||
- {"type": "error", "error": "...", "error_type": "..."} on error
|
- {"type": "error", "error": "...", "error_type": "..."} on error
|
||||||
Or None on unexpected error
|
Or None on unexpected error
|
||||||
"""
|
"""
|
||||||
|
if _is_dummy_mode():
|
||||||
|
return await decompose_goal_dummy(description, context, library_agents)
|
||||||
|
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
if context:
|
if context:
|
||||||
@@ -226,6 +256,11 @@ async def generate_agent_external(
|
|||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||||
"""
|
"""
|
||||||
|
if _is_dummy_mode():
|
||||||
|
return await generate_agent_dummy(
|
||||||
|
instructions, library_agents, operation_id, task_id
|
||||||
|
)
|
||||||
|
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
# Build request payload
|
# Build request payload
|
||||||
@@ -297,6 +332,11 @@ async def generate_agent_patch_external(
|
|||||||
Returns:
|
Returns:
|
||||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||||
"""
|
"""
|
||||||
|
if _is_dummy_mode():
|
||||||
|
return await generate_agent_patch_dummy(
|
||||||
|
update_request, current_agent, library_agents, operation_id, task_id
|
||||||
|
)
|
||||||
|
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
# Build request payload
|
# Build request payload
|
||||||
@@ -383,6 +423,11 @@ async def customize_template_external(
|
|||||||
Returns:
|
Returns:
|
||||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||||
"""
|
"""
|
||||||
|
if _is_dummy_mode():
|
||||||
|
return await customize_template_dummy(
|
||||||
|
template_agent, modification_request, context
|
||||||
|
)
|
||||||
|
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
request = modification_request
|
request = modification_request
|
||||||
@@ -445,6 +490,9 @@ async def get_blocks_external() -> list[dict[str, Any]] | None:
|
|||||||
Returns:
|
Returns:
|
||||||
List of block info dicts or None on error
|
List of block info dicts or None on error
|
||||||
"""
|
"""
|
||||||
|
if _is_dummy_mode():
|
||||||
|
return await get_blocks_dummy()
|
||||||
|
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -478,6 +526,9 @@ async def health_check() -> bool:
|
|||||||
if not is_external_service_configured():
|
if not is_external_service_configured():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if _is_dummy_mode():
|
||||||
|
return await health_check_dummy()
|
||||||
|
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -0,0 +1,131 @@
|
|||||||
|
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
|
||||||
|
|
||||||
|
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
|
||||||
|
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
|
||||||
|
read-only, writable workspace only, clean env, no network.
|
||||||
|
|
||||||
|
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
|
||||||
|
available (e.g. macOS development).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
|
from backend.api.features.chat.tools.models import (
|
||||||
|
BashExecResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
ToolResponseBase,
|
||||||
|
)
|
||||||
|
from backend.api.features.chat.tools.sandbox import (
|
||||||
|
get_workspace_dir,
|
||||||
|
has_full_sandbox,
|
||||||
|
run_sandboxed,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BashExecTool(BaseTool):
|
||||||
|
"""Execute Bash commands in a bubblewrap sandbox."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "bash_exec"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
if not has_full_sandbox():
|
||||||
|
return (
|
||||||
|
"Bash execution is DISABLED — bubblewrap sandbox is not "
|
||||||
|
"available on this platform. Do not call this tool."
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
"Execute a Bash command or script in a bubblewrap sandbox. "
|
||||||
|
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
||||||
|
"functions, etc.). "
|
||||||
|
"The sandbox shares the same working directory as the SDK Read/Write "
|
||||||
|
"tools — files created by either are accessible to both. "
|
||||||
|
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
|
||||||
|
"visible read-only, the per-session workspace is the only writable "
|
||||||
|
"path, environment variables are wiped (no secrets), all network "
|
||||||
|
"access is blocked at the kernel level, and resource limits are "
|
||||||
|
"enforced (max 64 processes, 512MB memory, 50MB file size). "
|
||||||
|
"Application code, configs, and other directories are NOT accessible. "
|
||||||
|
"To fetch web content, use the web_fetch tool instead. "
|
||||||
|
"Execution is killed after the timeout (default 30s, max 120s). "
|
||||||
|
"Returns stdout and stderr. "
|
||||||
|
"Useful for file manipulation, data processing with Unix tools "
|
||||||
|
"(grep, awk, sed, jq, etc.), and running shell scripts."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Bash command or script to execute.",
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": (
|
||||||
|
"Max execution time in seconds (default 30, max 120)."
|
||||||
|
),
|
||||||
|
"default": 30,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
|
if not has_full_sandbox():
|
||||||
|
return ErrorResponse(
|
||||||
|
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
||||||
|
error="sandbox_unavailable",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
command: str = (kwargs.get("command") or "").strip()
|
||||||
|
timeout: int = kwargs.get("timeout", 30)
|
||||||
|
|
||||||
|
if not command:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="No command provided.",
|
||||||
|
error="empty_command",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
workspace = get_workspace_dir(session_id or "default")
|
||||||
|
|
||||||
|
stdout, stderr, exit_code, timed_out = await run_sandboxed(
|
||||||
|
command=["bash", "-c", command],
|
||||||
|
cwd=workspace,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
return BashExecResponse(
|
||||||
|
message=(
|
||||||
|
"Execution timed out"
|
||||||
|
if timed_out
|
||||||
|
else f"Command executed (exit {exit_code})"
|
||||||
|
),
|
||||||
|
stdout=stdout,
|
||||||
|
stderr=stderr,
|
||||||
|
exit_code=exit_code,
|
||||||
|
timed_out=timed_out,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
"""CheckOperationStatusTool — query the status of a long-running operation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
|
from backend.api.features.chat.tools.models import (
|
||||||
|
ErrorResponse,
|
||||||
|
ResponseType,
|
||||||
|
ToolResponseBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OperationStatusResponse(ToolResponseBase):
|
||||||
|
"""Response for check_operation_status tool."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_STATUS
|
||||||
|
task_id: str
|
||||||
|
operation_id: str
|
||||||
|
status: str # "running", "completed", "failed"
|
||||||
|
tool_name: str | None = None
|
||||||
|
message: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class CheckOperationStatusTool(BaseTool):
|
||||||
|
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
|
||||||
|
|
||||||
|
The CoPilot uses this tool to report back to the user whether an
|
||||||
|
operation that was started earlier has completed, failed, or is still
|
||||||
|
running.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "check_operation_status"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Check the current status of a long-running operation such as "
|
||||||
|
"create_agent or edit_agent. Accepts either an operation_id or "
|
||||||
|
"task_id from a previous operation_started response. "
|
||||||
|
"Returns the current status: running, completed, or failed."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"operation_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The operation_id from an operation_started response."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"task_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The task_id from an operation_started response. "
|
||||||
|
"Used as fallback if operation_id is not provided."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
from backend.api.features.chat import stream_registry
|
||||||
|
|
||||||
|
operation_id = (kwargs.get("operation_id") or "").strip()
|
||||||
|
task_id = (kwargs.get("task_id") or "").strip()
|
||||||
|
|
||||||
|
if not operation_id and not task_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide an operation_id or task_id.",
|
||||||
|
error="missing_parameter",
|
||||||
|
)
|
||||||
|
|
||||||
|
task = None
|
||||||
|
if operation_id:
|
||||||
|
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||||
|
if task is None and task_id:
|
||||||
|
task = await stream_registry.get_task(task_id)
|
||||||
|
|
||||||
|
if task is None:
|
||||||
|
# Task not in Redis — it may have already expired (TTL).
|
||||||
|
# Check conversation history for the result instead.
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
"Operation not found — it may have already completed and "
|
||||||
|
"expired from the status tracker. Check the conversation "
|
||||||
|
"history for the result."
|
||||||
|
),
|
||||||
|
error="not_found",
|
||||||
|
)
|
||||||
|
|
||||||
|
status_messages = {
|
||||||
|
"running": (
|
||||||
|
f"The {task.tool_name or 'operation'} is still running. "
|
||||||
|
"Please wait for it to complete."
|
||||||
|
),
|
||||||
|
"completed": (
|
||||||
|
f"The {task.tool_name or 'operation'} has completed successfully."
|
||||||
|
),
|
||||||
|
"failed": f"The {task.tool_name or 'operation'} has failed.",
|
||||||
|
}
|
||||||
|
|
||||||
|
return OperationStatusResponse(
|
||||||
|
task_id=task.task_id,
|
||||||
|
operation_id=task.operation_id,
|
||||||
|
status=task.status,
|
||||||
|
tool_name=task.tool_name,
|
||||||
|
message=status_messages.get(task.status, f"Status: {task.status}"),
|
||||||
|
)
|
||||||
@@ -7,16 +7,38 @@ from backend.api.features.chat.model import ChatSession
|
|||||||
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
||||||
from backend.api.features.chat.tools.models import (
|
from backend.api.features.chat.tools.models import (
|
||||||
BlockInfoSummary,
|
BlockInfoSummary,
|
||||||
BlockInputFieldInfo,
|
|
||||||
BlockListResponse,
|
BlockListResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||||
from backend.data.block import get_block
|
from backend.blocks import get_block
|
||||||
|
from backend.blocks._base import BlockType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TARGET_RESULTS = 10
|
||||||
|
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
|
||||||
|
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
|
||||||
|
_OVERFETCH_PAGE_SIZE = 40
|
||||||
|
|
||||||
|
# Block types that only work within graphs and cannot run standalone in CoPilot.
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||||
|
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
|
||||||
|
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
|
||||||
|
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
|
||||||
|
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
|
||||||
|
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
||||||
|
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
||||||
|
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
||||||
|
}
|
||||||
|
|
||||||
|
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS = {
|
||||||
|
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
|
||||||
|
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class FindBlockTool(BaseTool):
|
class FindBlockTool(BaseTool):
|
||||||
"""Tool for searching available blocks."""
|
"""Tool for searching available blocks."""
|
||||||
@@ -32,7 +54,8 @@ class FindBlockTool(BaseTool):
|
|||||||
"Blocks are reusable components that perform specific tasks like "
|
"Blocks are reusable components that perform specific tasks like "
|
||||||
"sending emails, making API calls, processing text, etc. "
|
"sending emails, making API calls, processing text, etc. "
|
||||||
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
|
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
|
||||||
"The response includes each block's id, required_inputs, and input_schema."
|
"The response includes each block's id, name, and description. "
|
||||||
|
"Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it."
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -88,7 +111,7 @@ class FindBlockTool(BaseTool):
|
|||||||
query=query,
|
query=query,
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=1,
|
page=1,
|
||||||
page_size=10,
|
page_size=_OVERFETCH_PAGE_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
@@ -101,67 +124,44 @@ class FindBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Enrich results with full block information
|
# Enrich results with block information
|
||||||
blocks: list[BlockInfoSummary] = []
|
blocks: list[BlockInfoSummary] = []
|
||||||
for result in results:
|
for result in results:
|
||||||
block_id = result["content_id"]
|
block_id = result["content_id"]
|
||||||
block = get_block(block_id)
|
block = get_block(block_id)
|
||||||
|
|
||||||
# Skip disabled blocks
|
# Skip disabled blocks
|
||||||
if block and not block.disabled:
|
if not block or block.disabled:
|
||||||
# Get input/output schemas
|
continue
|
||||||
input_schema = {}
|
|
||||||
output_schema = {}
|
|
||||||
try:
|
|
||||||
input_schema = block.input_schema.jsonschema()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
output_schema = block.output_schema.jsonschema()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Get categories from block instance
|
# Skip blocks excluded from CoPilot (graph-only blocks)
|
||||||
categories = []
|
if (
|
||||||
if hasattr(block, "categories") and block.categories:
|
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
categories = [cat.value for cat in block.categories]
|
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
# Extract required inputs for easier use
|
blocks.append(
|
||||||
required_inputs: list[BlockInputFieldInfo] = []
|
BlockInfoSummary(
|
||||||
if input_schema:
|
id=block_id,
|
||||||
properties = input_schema.get("properties", {})
|
name=block.name,
|
||||||
required_fields = set(input_schema.get("required", []))
|
description=block.description or "",
|
||||||
# Get credential field names to exclude from required inputs
|
categories=[c.value for c in block.categories],
|
||||||
credentials_fields = set(
|
|
||||||
block.input_schema.get_credentials_fields().keys()
|
|
||||||
)
|
|
||||||
|
|
||||||
for field_name, field_schema in properties.items():
|
|
||||||
# Skip credential fields - they're handled separately
|
|
||||||
if field_name in credentials_fields:
|
|
||||||
continue
|
|
||||||
|
|
||||||
required_inputs.append(
|
|
||||||
BlockInputFieldInfo(
|
|
||||||
name=field_name,
|
|
||||||
type=field_schema.get("type", "string"),
|
|
||||||
description=field_schema.get("description", ""),
|
|
||||||
required=field_name in required_fields,
|
|
||||||
default=field_schema.get("default"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
blocks.append(
|
|
||||||
BlockInfoSummary(
|
|
||||||
id=block_id,
|
|
||||||
name=block.name,
|
|
||||||
description=block.description or "",
|
|
||||||
categories=categories,
|
|
||||||
input_schema=input_schema,
|
|
||||||
output_schema=output_schema,
|
|
||||||
required_inputs=required_inputs,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(blocks) >= _TARGET_RESULTS:
|
||||||
|
break
|
||||||
|
|
||||||
|
if blocks and len(blocks) < _TARGET_RESULTS:
|
||||||
|
logger.debug(
|
||||||
|
"find_block returned %d/%d results for query '%s' "
|
||||||
|
"(filtered %d excluded/disabled blocks)",
|
||||||
|
len(blocks),
|
||||||
|
_TARGET_RESULTS,
|
||||||
|
query,
|
||||||
|
len(results) - len(blocks),
|
||||||
|
)
|
||||||
|
|
||||||
if not blocks:
|
if not blocks:
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
@@ -175,8 +175,7 @@ class FindBlockTool(BaseTool):
|
|||||||
return BlockListResponse(
|
return BlockListResponse(
|
||||||
message=(
|
message=(
|
||||||
f"Found {len(blocks)} block(s) matching '{query}'. "
|
f"Found {len(blocks)} block(s) matching '{query}'. "
|
||||||
"To execute a block, use run_block with the block's 'id' field "
|
"To see a block's inputs/outputs and execute it, use run_block with the block's 'id' - providing no inputs."
|
||||||
"and provide 'input_data' matching the block's input_schema."
|
|
||||||
),
|
),
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
count=len(blocks),
|
count=len(blocks),
|
||||||
|
|||||||
@@ -0,0 +1,386 @@
|
|||||||
|
"""Tests for block filtering in FindBlockTool."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.find_block import (
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
|
FindBlockTool,
|
||||||
|
)
|
||||||
|
from backend.api.features.chat.tools.models import BlockListResponse
|
||||||
|
from backend.blocks._base import BlockType
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
_TEST_USER_ID = "test-user-find-block"
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_block(
|
||||||
|
block_id: str,
|
||||||
|
name: str,
|
||||||
|
block_type: BlockType,
|
||||||
|
disabled: bool = False,
|
||||||
|
input_schema: dict | None = None,
|
||||||
|
output_schema: dict | None = None,
|
||||||
|
credentials_fields: dict | None = None,
|
||||||
|
):
|
||||||
|
"""Create a mock block for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.description = f"{name} description"
|
||||||
|
mock.block_type = block_type
|
||||||
|
mock.disabled = disabled
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = input_schema or {
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
mock.input_schema.get_credentials_fields.return_value = credentials_fields or {}
|
||||||
|
mock.output_schema = MagicMock()
|
||||||
|
mock.output_schema.jsonschema.return_value = output_schema or {}
|
||||||
|
mock.categories = []
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindBlockFiltering:
|
||||||
|
"""Tests for block filtering in FindBlockTool."""
|
||||||
|
|
||||||
|
def test_excluded_block_types_contains_expected_types(self):
|
||||||
|
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
||||||
|
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
|
||||||
|
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
||||||
|
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||||
|
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_type_filtered_from_results(self):
|
||||||
|
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
||||||
|
search_results = [
|
||||||
|
{"content_id": "input-block-id", "score": 0.9},
|
||||||
|
{"content_id": "standard-block-id", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||||
|
standard_block = make_mock_block(
|
||||||
|
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_block(block_id):
|
||||||
|
return {
|
||||||
|
"input-block-id": input_block,
|
||||||
|
"standard-block-id": standard_block,
|
||||||
|
}.get(block_id)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, 2),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=mock_get_block,
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return the standard block, not the INPUT block
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert len(response.blocks) == 1
|
||||||
|
assert response.blocks[0].id == "standard-block-id"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_id_filtered_from_results(self):
|
||||||
|
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||||
|
search_results = [
|
||||||
|
{"content_id": smart_decision_id, "score": 0.9},
|
||||||
|
{"content_id": "normal-block-id", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
||||||
|
smart_block = make_mock_block(
|
||||||
|
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
normal_block = make_mock_block(
|
||||||
|
"normal-block-id", "Normal Block", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_block(block_id):
|
||||||
|
return {
|
||||||
|
smart_decision_id: smart_block,
|
||||||
|
"normal-block-id": normal_block,
|
||||||
|
}.get(block_id)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, 2),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=mock_get_block,
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="decision"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return normal block, not SmartDecisionMakerBlock
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert len(response.blocks) == 1
|
||||||
|
assert response.blocks[0].id == "normal-block-id"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_response_size_average_chars_per_block(self):
|
||||||
|
"""Measure average chars per block in the serialized response."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
# Realistic block definitions modeled after real blocks
|
||||||
|
block_defs = [
|
||||||
|
{
|
||||||
|
"id": "http-block-id",
|
||||||
|
"name": "Send Web Request",
|
||||||
|
"input_schema": {
|
||||||
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The URL to send the request to",
|
||||||
|
},
|
||||||
|
"method": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The HTTP method to use",
|
||||||
|
},
|
||||||
|
"headers": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Headers to include in the request",
|
||||||
|
},
|
||||||
|
"json_format": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "If true, send the body as JSON",
|
||||||
|
},
|
||||||
|
"body": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Form/JSON body payload",
|
||||||
|
},
|
||||||
|
"credentials": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "HTTP credentials",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["url", "method"],
|
||||||
|
},
|
||||||
|
"output_schema": {
|
||||||
|
"properties": {
|
||||||
|
"response": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The response from the server",
|
||||||
|
},
|
||||||
|
"client_error": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Errors on 4xx status codes",
|
||||||
|
},
|
||||||
|
"server_error": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Errors on 5xx status codes",
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Errors for all other exceptions",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"credentials_fields": {"credentials": True},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "email-block-id",
|
||||||
|
"name": "Send Email",
|
||||||
|
"input_schema": {
|
||||||
|
"properties": {
|
||||||
|
"to_email": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Recipient email address",
|
||||||
|
},
|
||||||
|
"subject": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Subject of the email",
|
||||||
|
},
|
||||||
|
"body": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Body of the email",
|
||||||
|
},
|
||||||
|
"config": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "SMTP Config",
|
||||||
|
},
|
||||||
|
"credentials": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "SMTP credentials",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["to_email", "subject", "body", "credentials"],
|
||||||
|
},
|
||||||
|
"output_schema": {
|
||||||
|
"properties": {
|
||||||
|
"status": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Status of the email sending operation",
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Error message if sending failed",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"credentials_fields": {"credentials": True},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "claude-code-block-id",
|
||||||
|
"name": "Claude Code",
|
||||||
|
"input_schema": {
|
||||||
|
"properties": {
|
||||||
|
"e2b_credentials": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "API key for E2B platform",
|
||||||
|
},
|
||||||
|
"anthropic_credentials": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "API key for Anthropic",
|
||||||
|
},
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Task or instruction for Claude Code",
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Sandbox timeout in seconds",
|
||||||
|
},
|
||||||
|
"setup_commands": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "Shell commands to run before execution",
|
||||||
|
},
|
||||||
|
"working_directory": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Working directory for Claude Code",
|
||||||
|
},
|
||||||
|
"session_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Session ID to resume a conversation",
|
||||||
|
},
|
||||||
|
"sandbox_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Sandbox ID to reconnect to",
|
||||||
|
},
|
||||||
|
"conversation_history": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Previous conversation history",
|
||||||
|
},
|
||||||
|
"dispose_sandbox": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether to dispose sandbox after execution",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"e2b_credentials",
|
||||||
|
"anthropic_credentials",
|
||||||
|
"prompt",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"output_schema": {
|
||||||
|
"properties": {
|
||||||
|
"response": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Output from Claude Code execution",
|
||||||
|
},
|
||||||
|
"files": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "Files created/modified by Claude Code",
|
||||||
|
},
|
||||||
|
"conversation_history": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Full conversation history",
|
||||||
|
},
|
||||||
|
"session_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Session ID for this conversation",
|
||||||
|
},
|
||||||
|
"sandbox_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "ID of the sandbox instance",
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Error message if execution failed",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"credentials_fields": {
|
||||||
|
"e2b_credentials": True,
|
||||||
|
"anthropic_credentials": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
search_results = [
|
||||||
|
{"content_id": d["id"], "score": 0.9 - i * 0.1}
|
||||||
|
for i, d in enumerate(block_defs)
|
||||||
|
]
|
||||||
|
mock_blocks = {
|
||||||
|
d["id"]: make_mock_block(
|
||||||
|
block_id=d["id"],
|
||||||
|
name=d["name"],
|
||||||
|
block_type=BlockType.STANDARD,
|
||||||
|
input_schema=d["input_schema"],
|
||||||
|
output_schema=d["output_schema"],
|
||||||
|
credentials_fields=d["credentials_fields"],
|
||||||
|
)
|
||||||
|
for d in block_defs
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, len(search_results)),
|
||||||
|
), patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=lambda bid: mock_blocks.get(bid),
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert response.count == len(block_defs)
|
||||||
|
|
||||||
|
total_chars = len(response.model_dump_json())
|
||||||
|
avg_chars = total_chars // response.count
|
||||||
|
|
||||||
|
# Print for visibility in test output
|
||||||
|
print(f"\nTotal response size: {total_chars} chars")
|
||||||
|
print(f"Number of blocks: {response.count}")
|
||||||
|
print(f"Average chars per block: {avg_chars}")
|
||||||
|
|
||||||
|
# The old response was ~90K for 10 blocks (~9K per block).
|
||||||
|
# Previous optimization reduced it to ~1.5K per block (no raw JSON schemas).
|
||||||
|
# Now with only id/name/description, we expect ~300 chars per block.
|
||||||
|
assert avg_chars < 500, (
|
||||||
|
f"Average chars per block ({avg_chars}) exceeds 500. "
|
||||||
|
f"Total response: {total_chars} chars for {response.count} blocks."
|
||||||
|
)
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
"""Shared helpers for chat tools."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def get_inputs_from_schema(
|
||||||
|
input_schema: dict[str, Any],
|
||||||
|
exclude_fields: set[str] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Extract input field info from JSON schema."""
|
||||||
|
if not isinstance(input_schema, dict):
|
||||||
|
return []
|
||||||
|
|
||||||
|
exclude = exclude_fields or set()
|
||||||
|
properties = input_schema.get("properties", {})
|
||||||
|
required = set(input_schema.get("required", []))
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"title": schema.get("title", name),
|
||||||
|
"type": schema.get("type", "string"),
|
||||||
|
"description": schema.get("description", ""),
|
||||||
|
"required": name in required,
|
||||||
|
"default": schema.get("default"),
|
||||||
|
}
|
||||||
|
for name, schema in properties.items()
|
||||||
|
if name not in exclude
|
||||||
|
]
|
||||||
@@ -25,6 +25,7 @@ class ResponseType(str, Enum):
|
|||||||
AGENT_SAVED = "agent_saved"
|
AGENT_SAVED = "agent_saved"
|
||||||
CLARIFICATION_NEEDED = "clarification_needed"
|
CLARIFICATION_NEEDED = "clarification_needed"
|
||||||
BLOCK_LIST = "block_list"
|
BLOCK_LIST = "block_list"
|
||||||
|
BLOCK_DETAILS = "block_details"
|
||||||
BLOCK_OUTPUT = "block_output"
|
BLOCK_OUTPUT = "block_output"
|
||||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||||
DOC_PAGE = "doc_page"
|
DOC_PAGE = "doc_page"
|
||||||
@@ -40,6 +41,12 @@ class ResponseType(str, Enum):
|
|||||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||||
# Input validation
|
# Input validation
|
||||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||||
|
# Web fetch
|
||||||
|
WEB_FETCH = "web_fetch"
|
||||||
|
# Code execution
|
||||||
|
BASH_EXEC = "bash_exec"
|
||||||
|
# Operation status check
|
||||||
|
OPERATION_STATUS = "operation_status"
|
||||||
|
|
||||||
|
|
||||||
# Base response model
|
# Base response model
|
||||||
@@ -335,11 +342,17 @@ class BlockInfoSummary(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
categories: list[str]
|
categories: list[str]
|
||||||
input_schema: dict[str, Any]
|
input_schema: dict[str, Any] = Field(
|
||||||
output_schema: dict[str, Any]
|
default_factory=dict,
|
||||||
|
description="Full JSON schema for block inputs",
|
||||||
|
)
|
||||||
|
output_schema: dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Full JSON schema for block outputs",
|
||||||
|
)
|
||||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
required_inputs: list[BlockInputFieldInfo] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="List of required input fields for this block",
|
description="List of input fields for this block",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -352,10 +365,29 @@ class BlockListResponse(ToolResponseBase):
|
|||||||
query: str
|
query: str
|
||||||
usage_hint: str = Field(
|
usage_hint: str = Field(
|
||||||
default="To execute a block, call run_block with block_id set to the block's "
|
default="To execute a block, call run_block with block_id set to the block's "
|
||||||
"'id' field and input_data containing the required fields from input_schema."
|
"'id' field and input_data containing the fields listed in required_inputs."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockDetails(BaseModel):
|
||||||
|
"""Detailed block information."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
inputs: dict[str, Any] = {}
|
||||||
|
outputs: dict[str, Any] = {}
|
||||||
|
credentials: list[CredentialsMetaInput] = []
|
||||||
|
|
||||||
|
|
||||||
|
class BlockDetailsResponse(ToolResponseBase):
|
||||||
|
"""Response for block details (first run_block attempt)."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.BLOCK_DETAILS
|
||||||
|
block: BlockDetails
|
||||||
|
user_authenticated: bool = False
|
||||||
|
|
||||||
|
|
||||||
class BlockOutputResponse(ToolResponseBase):
|
class BlockOutputResponse(ToolResponseBase):
|
||||||
"""Response for run_block tool."""
|
"""Response for run_block tool."""
|
||||||
|
|
||||||
@@ -421,3 +453,24 @@ class AsyncProcessingResponse(ToolResponseBase):
|
|||||||
status: str = "accepted" # Must be "accepted" for detection
|
status: str = "accepted" # Must be "accepted" for detection
|
||||||
operation_id: str | None = None
|
operation_id: str | None = None
|
||||||
task_id: str | None = None
|
task_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WebFetchResponse(ToolResponseBase):
|
||||||
|
"""Response for web_fetch tool."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WEB_FETCH
|
||||||
|
url: str
|
||||||
|
status_code: int
|
||||||
|
content_type: str
|
||||||
|
content: str
|
||||||
|
truncated: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class BashExecResponse(ToolResponseBase):
|
||||||
|
"""Response for bash_exec tool."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.BASH_EXEC
|
||||||
|
stdout: str
|
||||||
|
stderr: str
|
||||||
|
exit_code: int
|
||||||
|
timed_out: bool = False
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .helpers import get_inputs_from_schema
|
||||||
from .models import (
|
from .models import (
|
||||||
AgentDetails,
|
AgentDetails,
|
||||||
AgentDetailsResponse,
|
AgentDetailsResponse,
|
||||||
@@ -261,7 +262,7 @@ class RunAgentTool(BaseTool):
|
|||||||
),
|
),
|
||||||
requirements={
|
requirements={
|
||||||
"credentials": requirements_creds_list,
|
"credentials": requirements_creds_list,
|
||||||
"inputs": self._get_inputs_list(graph.input_schema),
|
"inputs": get_inputs_from_schema(graph.input_schema),
|
||||||
"execution_modes": self._get_execution_modes(graph),
|
"execution_modes": self._get_execution_modes(graph),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@@ -369,22 +370,6 @@ class RunAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
|
||||||
"""Extract inputs list from schema."""
|
|
||||||
inputs_list = []
|
|
||||||
if isinstance(input_schema, dict) and "properties" in input_schema:
|
|
||||||
for field_name, field_schema in input_schema["properties"].items():
|
|
||||||
inputs_list.append(
|
|
||||||
{
|
|
||||||
"name": field_name,
|
|
||||||
"title": field_schema.get("title", field_name),
|
|
||||||
"type": field_schema.get("type", "string"),
|
|
||||||
"description": field_schema.get("description", ""),
|
|
||||||
"required": field_name in input_schema.get("required", []),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return inputs_list
|
|
||||||
|
|
||||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||||
"""Get available execution modes for the graph."""
|
"""Get available execution modes for the graph."""
|
||||||
trigger_info = graph.trigger_setup_info
|
trigger_info = graph.trigger_setup_info
|
||||||
@@ -398,7 +383,7 @@ class RunAgentTool(BaseTool):
|
|||||||
suffix: str,
|
suffix: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a message describing available inputs for an agent."""
|
"""Build a message describing available inputs for an agent."""
|
||||||
inputs_list = self._get_inputs_list(graph.input_schema)
|
inputs_list = get_inputs_from_schema(graph.input_schema)
|
||||||
required_names = [i["name"] for i in inputs_list if i["required"]]
|
required_names = [i["name"] for i in inputs_list if i["required"]]
|
||||||
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
||||||
|
|
||||||
|
|||||||
@@ -8,23 +8,35 @@ from typing import Any
|
|||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.block import get_block
|
from backend.api.features.chat.tools.find_block import (
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
|
)
|
||||||
|
from backend.blocks import get_block
|
||||||
|
from backend.blocks._base import AnyBlockSchema
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .helpers import get_inputs_from_schema
|
||||||
from .models import (
|
from .models import (
|
||||||
|
BlockDetails,
|
||||||
|
BlockDetailsResponse,
|
||||||
BlockOutputResponse,
|
BlockOutputResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
|
InputValidationErrorResponse,
|
||||||
SetupInfo,
|
SetupInfo,
|
||||||
SetupRequirementsResponse,
|
SetupRequirementsResponse,
|
||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
UserReadiness,
|
UserReadiness,
|
||||||
)
|
)
|
||||||
from .utils import build_missing_credentials_from_field_info
|
from .utils import (
|
||||||
|
build_missing_credentials_from_field_info,
|
||||||
|
match_credentials_to_requirements,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -42,8 +54,8 @@ class RunBlockTool(BaseTool):
|
|||||||
"Execute a specific block with the provided input data. "
|
"Execute a specific block with the provided input data. "
|
||||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
||||||
"do NOT guess or make up block IDs. "
|
"do NOT guess or make up block IDs. "
|
||||||
"Use the 'id' from find_block results and provide input_data "
|
"On first attempt (without input_data), returns detailed schema showing "
|
||||||
"matching the block's required_inputs."
|
"required inputs and outputs. Then call again with proper input_data to execute."
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -58,11 +70,19 @@ class RunBlockTool(BaseTool):
|
|||||||
"NEVER guess this - always get it from find_block first."
|
"NEVER guess this - always get it from find_block first."
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
"block_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The block's human-readable name from find_block results. "
|
||||||
|
"Used for display purposes in the UI."
|
||||||
|
),
|
||||||
|
},
|
||||||
"input_data": {
|
"input_data": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"description": (
|
"description": (
|
||||||
"Input values for the block. Use the 'required_inputs' field "
|
"Input values for the block. "
|
||||||
"from find_block to see what fields are needed."
|
"First call with empty {} to see the block's schema, "
|
||||||
|
"then call again with proper values to execute."
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -73,91 +93,6 @@ class RunBlockTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _check_block_credentials(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
block: Any,
|
|
||||||
input_data: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Check if user has required credentials for a block.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
block: Block to check credentials for
|
|
||||||
input_data: Input data for the block (used to determine provider via discriminator)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[matched_credentials, missing_credentials]
|
|
||||||
"""
|
|
||||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
|
||||||
missing_credentials: list[CredentialsMetaInput] = []
|
|
||||||
input_data = input_data or {}
|
|
||||||
|
|
||||||
# Get credential field info from block's input schema
|
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
|
||||||
|
|
||||||
if not credentials_fields_info:
|
|
||||||
return matched_credentials, missing_credentials
|
|
||||||
|
|
||||||
# Get user's available credentials
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
|
||||||
effective_field_info = field_info
|
|
||||||
if field_info.discriminator and field_info.discriminator_mapping:
|
|
||||||
# Get discriminator from input, falling back to schema default
|
|
||||||
discriminator_value = input_data.get(field_info.discriminator)
|
|
||||||
if discriminator_value is None:
|
|
||||||
field = block.input_schema.model_fields.get(
|
|
||||||
field_info.discriminator
|
|
||||||
)
|
|
||||||
if field and field.default is not PydanticUndefined:
|
|
||||||
discriminator_value = field.default
|
|
||||||
|
|
||||||
if (
|
|
||||||
discriminator_value
|
|
||||||
and discriminator_value in field_info.discriminator_mapping
|
|
||||||
):
|
|
||||||
effective_field_info = field_info.discriminate(discriminator_value)
|
|
||||||
logger.debug(
|
|
||||||
f"Discriminated provider for {field_name}: "
|
|
||||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
|
||||||
)
|
|
||||||
|
|
||||||
matching_cred = next(
|
|
||||||
(
|
|
||||||
cred
|
|
||||||
for cred in available_creds
|
|
||||||
if cred.provider in effective_field_info.provider
|
|
||||||
and cred.type in effective_field_info.supported_types
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if matching_cred:
|
|
||||||
matched_credentials[field_name] = CredentialsMetaInput(
|
|
||||||
id=matching_cred.id,
|
|
||||||
provider=matching_cred.provider, # type: ignore
|
|
||||||
type=matching_cred.type,
|
|
||||||
title=matching_cred.title,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Create a placeholder for the missing credential
|
|
||||||
provider = next(iter(effective_field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
|
||||||
missing_credentials.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=field_name.replace("_", " ").title(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matched_credentials, missing_credentials
|
|
||||||
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -212,13 +147,54 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if block is excluded from CoPilot (graph-only blocks)
|
||||||
|
if (
|
||||||
|
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
):
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
||||||
|
"This block is designed for use within graphs only."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||||
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
creds_manager = IntegrationCredentialsManager()
|
||||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
matched_credentials, missing_credentials = (
|
||||||
user_id, block, input_data
|
await self._resolve_block_credentials(user_id, block, input_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Get block schemas for details/validation
|
||||||
|
try:
|
||||||
|
input_schema: dict[str, Any] = block.input_schema.jsonschema()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to generate input schema for block %s: %s",
|
||||||
|
block_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Block '{block.name}' has an invalid input schema",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
output_schema: dict[str, Any] = block.output_schema.jsonschema()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to generate output schema for block %s: %s",
|
||||||
|
block_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Block '{block.name}' has an invalid output schema",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
if missing_credentials:
|
if missing_credentials:
|
||||||
# Return setup requirements response with missing credentials
|
# Return setup requirements response with missing credentials
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||||
@@ -251,6 +227,53 @@ class RunBlockTool(BaseTool):
|
|||||||
graph_version=None,
|
graph_version=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if this is a first attempt (required inputs missing)
|
||||||
|
# Return block details so user can see what inputs are needed
|
||||||
|
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||||
|
required_keys = set(input_schema.get("required", []))
|
||||||
|
required_non_credential_keys = required_keys - credentials_fields
|
||||||
|
provided_input_keys = set(input_data.keys()) - credentials_fields
|
||||||
|
|
||||||
|
# Check for unknown input fields
|
||||||
|
valid_fields = (
|
||||||
|
set(input_schema.get("properties", {}).keys()) - credentials_fields
|
||||||
|
)
|
||||||
|
unrecognized_fields = provided_input_keys - valid_fields
|
||||||
|
if unrecognized_fields:
|
||||||
|
return InputValidationErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
|
||||||
|
f"Block was not executed. Please use the correct field names from the schema."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
unrecognized_fields=sorted(unrecognized_fields),
|
||||||
|
inputs=input_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Show details when not all required non-credential inputs are provided
|
||||||
|
if not (required_non_credential_keys <= provided_input_keys):
|
||||||
|
# Get credentials info for the response
|
||||||
|
credentials_meta = []
|
||||||
|
for field_name, cred_meta in matched_credentials.items():
|
||||||
|
credentials_meta.append(cred_meta)
|
||||||
|
|
||||||
|
return BlockDetailsResponse(
|
||||||
|
message=(
|
||||||
|
f"Block '{block.name}' details. "
|
||||||
|
"Provide input_data matching the inputs schema to execute the block."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
block=BlockDetails(
|
||||||
|
id=block_id,
|
||||||
|
name=block.name,
|
||||||
|
description=block.description or "",
|
||||||
|
inputs=input_schema,
|
||||||
|
outputs=output_schema,
|
||||||
|
credentials=credentials_meta,
|
||||||
|
),
|
||||||
|
user_authenticated=True,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get or create user's workspace for CoPilot file operations
|
# Get or create user's workspace for CoPilot file operations
|
||||||
workspace = await get_or_create_workspace(user_id)
|
workspace = await get_or_create_workspace(user_id)
|
||||||
@@ -345,29 +368,75 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
async def _resolve_block_credentials(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
block: AnyBlockSchema,
|
||||||
|
input_data: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
|
"""
|
||||||
|
Resolve credentials for a block by matching user's available credentials.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
block: Block to resolve credentials for
|
||||||
|
input_data: Input data for the block (used to determine provider via discriminator)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple of (matched_credentials, missing_credentials) - matched credentials
|
||||||
|
are used for block execution, missing ones indicate setup requirements.
|
||||||
|
"""
|
||||||
|
input_data = input_data or {}
|
||||||
|
requirements = self._resolve_discriminated_credentials(block, input_data)
|
||||||
|
|
||||||
|
if not requirements:
|
||||||
|
return {}, []
|
||||||
|
|
||||||
|
return await match_credentials_to_requirements(user_id, requirements)
|
||||||
|
|
||||||
|
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
||||||
"""Extract non-credential inputs from block schema."""
|
"""Extract non-credential inputs from block schema."""
|
||||||
inputs_list = []
|
|
||||||
schema = block.input_schema.jsonschema()
|
schema = block.input_schema.jsonschema()
|
||||||
properties = schema.get("properties", {})
|
|
||||||
required_fields = set(schema.get("required", []))
|
|
||||||
|
|
||||||
# Get credential field names to exclude
|
|
||||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||||
|
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
||||||
|
|
||||||
for field_name, field_schema in properties.items():
|
def _resolve_discriminated_credentials(
|
||||||
# Skip credential fields
|
self,
|
||||||
if field_name in credentials_fields:
|
block: AnyBlockSchema,
|
||||||
continue
|
input_data: dict[str, Any],
|
||||||
|
) -> dict[str, CredentialsFieldInfo]:
|
||||||
|
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||||
|
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||||
|
if not credentials_fields_info:
|
||||||
|
return {}
|
||||||
|
|
||||||
inputs_list.append(
|
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||||
{
|
|
||||||
"name": field_name,
|
|
||||||
"title": field_schema.get("title", field_name),
|
|
||||||
"type": field_schema.get("type", "string"),
|
|
||||||
"description": field_schema.get("description", ""),
|
|
||||||
"required": field_name in required_fields,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_list
|
for field_name, field_info in credentials_fields_info.items():
|
||||||
|
effective_field_info = field_info
|
||||||
|
|
||||||
|
if field_info.discriminator and field_info.discriminator_mapping:
|
||||||
|
discriminator_value = input_data.get(field_info.discriminator)
|
||||||
|
if discriminator_value is None:
|
||||||
|
field = block.input_schema.model_fields.get(
|
||||||
|
field_info.discriminator
|
||||||
|
)
|
||||||
|
if field and field.default is not PydanticUndefined:
|
||||||
|
discriminator_value = field.default
|
||||||
|
|
||||||
|
if (
|
||||||
|
discriminator_value
|
||||||
|
and discriminator_value in field_info.discriminator_mapping
|
||||||
|
):
|
||||||
|
effective_field_info = field_info.discriminate(discriminator_value)
|
||||||
|
# For host-scoped credentials, add the discriminator value
|
||||||
|
# (e.g., URL) so _credential_is_for_host can match it
|
||||||
|
effective_field_info.discriminator_values.add(discriminator_value)
|
||||||
|
logger.debug(
|
||||||
|
f"Discriminated provider for {field_name}: "
|
||||||
|
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved[field_name] = effective_field_info
|
||||||
|
|
||||||
|
return resolved
|
||||||
|
|||||||
@@ -0,0 +1,362 @@
|
|||||||
|
"""Tests for block execution guards and input validation in RunBlockTool."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.models import (
|
||||||
|
BlockDetailsResponse,
|
||||||
|
BlockOutputResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
InputValidationErrorResponse,
|
||||||
|
)
|
||||||
|
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||||
|
from backend.blocks._base import BlockType
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
_TEST_USER_ID = "test-user-run-block"
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_block(
|
||||||
|
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||||
|
):
|
||||||
|
"""Create a mock block for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.block_type = block_type
|
||||||
|
mock.disabled = disabled
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||||
|
mock.input_schema.get_credentials_fields_info.return_value = []
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_block_with_schema(
|
||||||
|
block_id: str,
|
||||||
|
name: str,
|
||||||
|
input_properties: dict,
|
||||||
|
required_fields: list[str],
|
||||||
|
output_properties: dict | None = None,
|
||||||
|
):
|
||||||
|
"""Create a mock block with a defined input/output schema for validation tests."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.block_type = BlockType.STANDARD
|
||||||
|
mock.disabled = False
|
||||||
|
mock.description = f"Test block: {name}"
|
||||||
|
|
||||||
|
input_schema = {
|
||||||
|
"properties": input_properties,
|
||||||
|
"required": required_fields,
|
||||||
|
}
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = input_schema
|
||||||
|
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||||
|
mock.input_schema.get_credentials_fields.return_value = {}
|
||||||
|
|
||||||
|
output_schema = {
|
||||||
|
"properties": output_properties or {"result": {"type": "string"}},
|
||||||
|
}
|
||||||
|
mock.output_schema = MagicMock()
|
||||||
|
mock.output_schema.jsonschema.return_value = output_schema
|
||||||
|
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunBlockFiltering:
|
||||||
|
"""Tests for block execution guards in RunBlockTool."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_type_returns_error(self):
|
||||||
|
"""Attempting to execute a block with excluded BlockType returns error."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=input_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="input-block-id",
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, ErrorResponse)
|
||||||
|
assert "cannot be run directly in CoPilot" in response.message
|
||||||
|
assert "designed for use within graphs only" in response.message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_id_returns_error(self):
|
||||||
|
"""Attempting to execute SmartDecisionMakerBlock returns error."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||||
|
smart_block = make_mock_block(
|
||||||
|
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=smart_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id=smart_decision_id,
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, ErrorResponse)
|
||||||
|
assert "cannot be run directly in CoPilot" in response.message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_non_excluded_block_passes_guard(self):
|
||||||
|
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
standard_block = make_mock_block(
|
||||||
|
"standard-id", "HTTP Request", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=standard_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="standard-id",
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should NOT be an ErrorResponse about CoPilot exclusion
|
||||||
|
# (may be other errors like missing credentials, but not the exclusion guard)
|
||||||
|
if isinstance(response, ErrorResponse):
|
||||||
|
assert "cannot be run directly in CoPilot" not in response.message
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunBlockInputValidation:
|
||||||
|
"""Tests for input field validation in RunBlockTool.
|
||||||
|
|
||||||
|
run_block rejects unknown input field names with InputValidationErrorResponse,
|
||||||
|
preventing silent failures where incorrect keys would be ignored and the block
|
||||||
|
would execute with default values instead of the caller's intended values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_unknown_input_fields_are_rejected(self):
|
||||||
|
"""run_block rejects unknown input fields instead of silently ignoring them.
|
||||||
|
|
||||||
|
Scenario: The AI Text Generator block has a field called 'model' (for LLM model
|
||||||
|
selection), but the LLM calling the tool guesses wrong and sends 'LLM_Model'
|
||||||
|
instead. The block should reject the request and return the valid schema.
|
||||||
|
"""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
mock_block = make_mock_block_with_schema(
|
||||||
|
block_id="ai-text-gen-id",
|
||||||
|
name="AI Text Generator",
|
||||||
|
input_properties={
|
||||||
|
"prompt": {"type": "string", "description": "The prompt to send"},
|
||||||
|
"model": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The LLM model to use",
|
||||||
|
"default": "gpt-4o-mini",
|
||||||
|
},
|
||||||
|
"sys_prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "System prompt",
|
||||||
|
"default": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required_fields=["prompt"],
|
||||||
|
output_properties={"response": {"type": "string"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=mock_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
|
||||||
|
# Provide 'prompt' (correct) but 'LLM_Model' instead of 'model' (wrong key)
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="ai-text-gen-id",
|
||||||
|
input_data={
|
||||||
|
"prompt": "Write a haiku about coding",
|
||||||
|
"LLM_Model": "claude-opus-4-6", # WRONG KEY - should be 'model'
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, InputValidationErrorResponse)
|
||||||
|
assert "LLM_Model" in response.unrecognized_fields
|
||||||
|
assert "Block was not executed" in response.message
|
||||||
|
assert "inputs" in response.model_dump() # valid schema included
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_multiple_wrong_keys_are_all_reported(self):
|
||||||
|
"""All unrecognized field names are reported in a single error response."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
mock_block = make_mock_block_with_schema(
|
||||||
|
block_id="ai-text-gen-id",
|
||||||
|
name="AI Text Generator",
|
||||||
|
input_properties={
|
||||||
|
"prompt": {"type": "string"},
|
||||||
|
"model": {"type": "string", "default": "gpt-4o-mini"},
|
||||||
|
"sys_prompt": {"type": "string", "default": ""},
|
||||||
|
"retry": {"type": "integer", "default": 3},
|
||||||
|
},
|
||||||
|
required_fields=["prompt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=mock_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="ai-text-gen-id",
|
||||||
|
input_data={
|
||||||
|
"prompt": "Hello", # correct
|
||||||
|
"llm_model": "claude-opus-4-6", # WRONG - should be 'model'
|
||||||
|
"system_prompt": "Be helpful", # WRONG - should be 'sys_prompt'
|
||||||
|
"retries": 5, # WRONG - should be 'retry'
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, InputValidationErrorResponse)
|
||||||
|
assert set(response.unrecognized_fields) == {
|
||||||
|
"llm_model",
|
||||||
|
"system_prompt",
|
||||||
|
"retries",
|
||||||
|
}
|
||||||
|
assert "Block was not executed" in response.message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_unknown_fields_rejected_even_with_missing_required(self):
|
||||||
|
"""Unknown fields are caught before the missing-required-fields check."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
mock_block = make_mock_block_with_schema(
|
||||||
|
block_id="ai-text-gen-id",
|
||||||
|
name="AI Text Generator",
|
||||||
|
input_properties={
|
||||||
|
"prompt": {"type": "string"},
|
||||||
|
"model": {"type": "string", "default": "gpt-4o-mini"},
|
||||||
|
},
|
||||||
|
required_fields=["prompt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=mock_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
|
||||||
|
# 'prompt' is missing AND 'LLM_Model' is an unknown field
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="ai-text-gen-id",
|
||||||
|
input_data={
|
||||||
|
"LLM_Model": "claude-opus-4-6", # wrong key, and 'prompt' is missing
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unknown fields are caught first
|
||||||
|
assert isinstance(response, InputValidationErrorResponse)
|
||||||
|
assert "LLM_Model" in response.unrecognized_fields
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_correct_inputs_still_execute(self):
|
||||||
|
"""Correct input field names pass validation and the block executes."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
mock_block = make_mock_block_with_schema(
|
||||||
|
block_id="ai-text-gen-id",
|
||||||
|
name="AI Text Generator",
|
||||||
|
input_properties={
|
||||||
|
"prompt": {"type": "string"},
|
||||||
|
"model": {"type": "string", "default": "gpt-4o-mini"},
|
||||||
|
},
|
||||||
|
required_fields=["prompt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_execute(input_data, **kwargs):
|
||||||
|
yield "response", "Generated text"
|
||||||
|
|
||||||
|
mock_block.execute = mock_execute
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=mock_block,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_or_create_workspace",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=MagicMock(id="test-workspace-id"),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="ai-text-gen-id",
|
||||||
|
input_data={
|
||||||
|
"prompt": "Write a haiku",
|
||||||
|
"model": "gpt-4o-mini", # correct field name
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, BlockOutputResponse)
|
||||||
|
assert response.success is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_missing_required_fields_returns_details(self):
|
||||||
|
"""Missing required fields returns BlockDetailsResponse with schema."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
mock_block = make_mock_block_with_schema(
|
||||||
|
block_id="ai-text-gen-id",
|
||||||
|
name="AI Text Generator",
|
||||||
|
input_properties={
|
||||||
|
"prompt": {"type": "string"},
|
||||||
|
"model": {"type": "string", "default": "gpt-4o-mini"},
|
||||||
|
},
|
||||||
|
required_fields=["prompt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=mock_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
|
||||||
|
# Only provide valid optional field, missing required 'prompt'
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="ai-text-gen-id",
|
||||||
|
input_data={
|
||||||
|
"model": "gpt-4o-mini", # valid but optional
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, BlockDetailsResponse)
|
||||||
@@ -0,0 +1,265 @@
|
|||||||
|
"""Sandbox execution utilities for code execution tools.
|
||||||
|
|
||||||
|
Provides filesystem + network isolated command execution using **bubblewrap**
|
||||||
|
(``bwrap``): whitelist-only filesystem (only system dirs visible read-only),
|
||||||
|
writable workspace only, clean environment, network blocked.
|
||||||
|
|
||||||
|
Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox`
|
||||||
|
and refuse to run if bubblewrap is not available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_DEFAULT_TIMEOUT = 30
|
||||||
|
_MAX_TIMEOUT = 120
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sandbox capability detection (cached at first call)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_BWRAP_AVAILABLE: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def has_full_sandbox() -> bool:
|
||||||
|
"""Return True if bubblewrap is available (filesystem + network isolation).
|
||||||
|
|
||||||
|
On non-Linux platforms (macOS), always returns False.
|
||||||
|
"""
|
||||||
|
global _BWRAP_AVAILABLE
|
||||||
|
if _BWRAP_AVAILABLE is None:
|
||||||
|
_BWRAP_AVAILABLE = (
|
||||||
|
platform.system() == "Linux" and shutil.which("bwrap") is not None
|
||||||
|
)
|
||||||
|
return _BWRAP_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
|
WORKSPACE_PREFIX = "/tmp/copilot-"
|
||||||
|
|
||||||
|
|
||||||
|
def make_session_path(session_id: str) -> str:
|
||||||
|
"""Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`.
|
||||||
|
|
||||||
|
Shared by both the SDK working-directory setup and the sandbox tools so
|
||||||
|
they always resolve to the same directory for a given session.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Strip all characters except ``[A-Za-z0-9-]``.
|
||||||
|
2. Construct ``/tmp/copilot-<safe_id>``.
|
||||||
|
3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised
|
||||||
|
sanitizer) to prevent path traversal.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the resulting path escapes the prefix.
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
|
||||||
|
if not safe_id:
|
||||||
|
safe_id = "default"
|
||||||
|
path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}")
|
||||||
|
if not path.startswith(WORKSPACE_PREFIX):
|
||||||
|
raise ValueError(f"Session path escaped prefix: {path}")
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def get_workspace_dir(session_id: str) -> str:
|
||||||
|
"""Get or create the workspace directory for a session.
|
||||||
|
|
||||||
|
Uses :func:`make_session_path` — the same path the SDK uses — so that
|
||||||
|
bash_exec shares the workspace with the SDK file tools.
|
||||||
|
"""
|
||||||
|
workspace = make_session_path(session_id)
|
||||||
|
os.makedirs(workspace, exist_ok=True)
|
||||||
|
return workspace
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bubblewrap command builder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# System directories mounted read-only inside the sandbox.
|
||||||
|
# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible.
|
||||||
|
_SYSTEM_RO_BINDS = [
|
||||||
|
"/usr", # binaries, libraries, Python interpreter
|
||||||
|
"/etc", # system config: ld.so, locale, passwd, alternatives
|
||||||
|
]
|
||||||
|
|
||||||
|
# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems.
|
||||||
|
# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind
|
||||||
|
# can't create a symlink target, so we detect and use --symlink instead.
|
||||||
|
# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2.
|
||||||
|
_COMPAT_PATHS = [
|
||||||
|
("/bin", "usr/bin"), # -> /usr/bin on Debian 13
|
||||||
|
("/sbin", "usr/sbin"), # -> /usr/sbin on Debian 13
|
||||||
|
("/lib", "usr/lib"), # -> /usr/lib on Debian 13
|
||||||
|
("/lib64", "usr/lib64"), # 64-bit libraries / ELF interpreter
|
||||||
|
]
|
||||||
|
|
||||||
|
# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse.
|
||||||
|
# Applied via ulimit inside the sandbox before exec'ing the user command.
|
||||||
|
_RESOURCE_LIMITS = (
|
||||||
|
"ulimit -u 64" # max 64 processes (prevents fork bombs)
|
||||||
|
" -v 524288" # 512 MB virtual memory
|
||||||
|
" -f 51200" # 50 MB max file size (1024-byte blocks)
|
||||||
|
" -n 256" # 256 open file descriptors
|
||||||
|
" 2>/dev/null"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_bwrap_command(
|
||||||
|
command: list[str], cwd: str, env: dict[str, str]
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build a bubblewrap command with strict filesystem + network isolation.
|
||||||
|
|
||||||
|
Security model:
|
||||||
|
- **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``,
|
||||||
|
``/bin``, ``/lib``) are mounted read-only. Application code (``/app``),
|
||||||
|
home directories, ``/var``, ``/opt``, etc. are NOT accessible at all.
|
||||||
|
- **Writable workspace only**: the per-session workspace is the sole
|
||||||
|
writable path.
|
||||||
|
- **Clean environment**: ``--clearenv`` wipes all inherited env vars.
|
||||||
|
Only the explicitly-passed safe env vars are set inside the sandbox.
|
||||||
|
- **Network isolation**: ``--unshare-net`` blocks all network access.
|
||||||
|
- **Resource limits**: ulimit caps on processes (64), memory (512MB),
|
||||||
|
file size (50MB), and open FDs (256) to prevent fork bombs and abuse.
|
||||||
|
- **New session**: prevents terminal control escape.
|
||||||
|
- **Die with parent**: prevents orphaned sandbox processes.
|
||||||
|
"""
|
||||||
|
cmd = [
|
||||||
|
"bwrap",
|
||||||
|
# Create a new user namespace so bwrap can set up sandboxing
|
||||||
|
# inside unprivileged Docker containers (no CAP_SYS_ADMIN needed).
|
||||||
|
"--unshare-user",
|
||||||
|
# Wipe all inherited environment variables (API keys, secrets, etc.)
|
||||||
|
"--clearenv",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Set only the safe env vars inside the sandbox
|
||||||
|
for key, value in env.items():
|
||||||
|
cmd.extend(["--setenv", key, value])
|
||||||
|
|
||||||
|
# System directories: read-only
|
||||||
|
for path in _SYSTEM_RO_BINDS:
|
||||||
|
cmd.extend(["--ro-bind", path, path])
|
||||||
|
|
||||||
|
# Compat paths: use --symlink when host path is a symlink (Debian 13),
|
||||||
|
# --ro-bind when it's a real directory (older distros).
|
||||||
|
for path, symlink_target in _COMPAT_PATHS:
|
||||||
|
if os.path.islink(path):
|
||||||
|
cmd.extend(["--symlink", symlink_target, path])
|
||||||
|
elif os.path.exists(path):
|
||||||
|
cmd.extend(["--ro-bind", path, path])
|
||||||
|
|
||||||
|
# Wrap the user command with resource limits:
|
||||||
|
# sh -c 'ulimit ...; exec "$@"' -- <original command>
|
||||||
|
# `exec "$@"` replaces the shell so there's no extra process overhead,
|
||||||
|
# and properly handles arguments with spaces.
|
||||||
|
limited_command = [
|
||||||
|
"sh",
|
||||||
|
"-c",
|
||||||
|
f'{_RESOURCE_LIMITS}; exec "$@"',
|
||||||
|
"--",
|
||||||
|
*command,
|
||||||
|
]
|
||||||
|
|
||||||
|
cmd.extend(
|
||||||
|
[
|
||||||
|
# Fresh virtual filesystems
|
||||||
|
"--dev",
|
||||||
|
"/dev",
|
||||||
|
"--proc",
|
||||||
|
"/proc",
|
||||||
|
"--tmpfs",
|
||||||
|
"/tmp",
|
||||||
|
# Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs.
|
||||||
|
# (workspace lives under /tmp/copilot-<session>)
|
||||||
|
"--bind",
|
||||||
|
cwd,
|
||||||
|
cwd,
|
||||||
|
# Isolation
|
||||||
|
"--unshare-net",
|
||||||
|
"--die-with-parent",
|
||||||
|
"--new-session",
|
||||||
|
"--chdir",
|
||||||
|
cwd,
|
||||||
|
"--",
|
||||||
|
*limited_command,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def run_sandboxed(
|
||||||
|
command: list[str],
|
||||||
|
cwd: str,
|
||||||
|
timeout: int = _DEFAULT_TIMEOUT,
|
||||||
|
env: dict[str, str] | None = None,
|
||||||
|
) -> tuple[str, str, int, bool]:
|
||||||
|
"""Run a command inside a bubblewrap sandbox.
|
||||||
|
|
||||||
|
Callers **must** check :func:`has_full_sandbox` before calling this
|
||||||
|
function. If bubblewrap is not available, this function raises
|
||||||
|
:class:`RuntimeError` rather than running unsandboxed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(stdout, stderr, exit_code, timed_out)
|
||||||
|
"""
|
||||||
|
if not has_full_sandbox():
|
||||||
|
raise RuntimeError(
|
||||||
|
"run_sandboxed() requires bubblewrap but bwrap is not available. "
|
||||||
|
"Callers must check has_full_sandbox() before calling this function."
|
||||||
|
)
|
||||||
|
|
||||||
|
timeout = min(max(timeout, 1), _MAX_TIMEOUT)
|
||||||
|
|
||||||
|
safe_env = {
|
||||||
|
"PATH": "/usr/local/bin:/usr/bin:/bin",
|
||||||
|
"HOME": cwd,
|
||||||
|
"TMPDIR": cwd,
|
||||||
|
"LANG": "en_US.UTF-8",
|
||||||
|
"PYTHONDONTWRITEBYTECODE": "1",
|
||||||
|
"PYTHONIOENCODING": "utf-8",
|
||||||
|
}
|
||||||
|
if env:
|
||||||
|
safe_env.update(env)
|
||||||
|
|
||||||
|
full_command = _build_bwrap_command(command, cwd, safe_env)
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*full_command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=cwd,
|
||||||
|
env=safe_env,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout_bytes, stderr_bytes = await asyncio.wait_for(
|
||||||
|
proc.communicate(), timeout=timeout
|
||||||
|
)
|
||||||
|
stdout = stdout_bytes.decode("utf-8", errors="replace")
|
||||||
|
stderr = stderr_bytes.decode("utf-8", errors="replace")
|
||||||
|
return stdout, stderr, proc.returncode or 0, False
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
proc.kill()
|
||||||
|
await proc.communicate()
|
||||||
|
return "", f"Execution timed out after {timeout}s", -1, True
|
||||||
|
|
||||||
|
except RuntimeError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return "", f"Sandbox error: {e}", -1, False
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
"""Tests for BlockDetailsResponse in RunBlockTool."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.models import BlockDetailsResponse
|
||||||
|
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||||
|
from backend.blocks._base import BlockType
|
||||||
|
from backend.data.model import CredentialsMetaInput
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
_TEST_USER_ID = "test-user-run-block-details"
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_block_with_inputs(
|
||||||
|
block_id: str, name: str, description: str = "Test description"
|
||||||
|
):
|
||||||
|
"""Create a mock block with input/output schemas for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.description = description
|
||||||
|
mock.block_type = BlockType.STANDARD
|
||||||
|
mock.disabled = False
|
||||||
|
|
||||||
|
# Input schema with non-credential fields
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = {
|
||||||
|
"properties": {
|
||||||
|
"url": {"type": "string", "description": "URL to fetch"},
|
||||||
|
"method": {"type": "string", "description": "HTTP method"},
|
||||||
|
},
|
||||||
|
"required": ["url"],
|
||||||
|
}
|
||||||
|
mock.input_schema.get_credentials_fields.return_value = {}
|
||||||
|
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||||
|
|
||||||
|
# Output schema
|
||||||
|
mock.output_schema = MagicMock()
|
||||||
|
mock.output_schema.jsonschema.return_value = {
|
||||||
|
"properties": {
|
||||||
|
"response": {"type": "object", "description": "HTTP response"},
|
||||||
|
"error": {"type": "string", "description": "Error message"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_block_returns_details_when_no_input_provided():
|
||||||
|
"""When run_block is called without input_data, it should return BlockDetailsResponse."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
# Create a block with inputs
|
||||||
|
http_block = make_mock_block_with_inputs(
|
||||||
|
"http-block-id", "HTTP Request", "Send HTTP requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=http_block,
|
||||||
|
):
|
||||||
|
# Mock credentials check to return no missing credentials
|
||||||
|
with patch.object(
|
||||||
|
RunBlockTool,
|
||||||
|
"_resolve_block_credentials",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=({}, []), # (matched_credentials, missing_credentials)
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="http-block-id",
|
||||||
|
input_data={}, # Empty input data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return BlockDetailsResponse showing the schema
|
||||||
|
assert isinstance(response, BlockDetailsResponse)
|
||||||
|
assert response.block.id == "http-block-id"
|
||||||
|
assert response.block.name == "HTTP Request"
|
||||||
|
assert response.block.description == "Send HTTP requests"
|
||||||
|
assert "url" in response.block.inputs["properties"]
|
||||||
|
assert "method" in response.block.inputs["properties"]
|
||||||
|
assert "response" in response.block.outputs["properties"]
|
||||||
|
assert response.user_authenticated is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_block_returns_details_when_only_credentials_provided():
|
||||||
|
"""When only credentials are provided (no actual input), should return details."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
# Create a block with both credential and non-credential inputs
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = "api-block-id"
|
||||||
|
mock.name = "API Call"
|
||||||
|
mock.description = "Make API calls"
|
||||||
|
mock.block_type = BlockType.STANDARD
|
||||||
|
mock.disabled = False
|
||||||
|
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = {
|
||||||
|
"properties": {
|
||||||
|
"credentials": {"type": "object", "description": "API credentials"},
|
||||||
|
"endpoint": {"type": "string", "description": "API endpoint"},
|
||||||
|
},
|
||||||
|
"required": ["credentials", "endpoint"],
|
||||||
|
}
|
||||||
|
mock.input_schema.get_credentials_fields.return_value = {"credentials": True}
|
||||||
|
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||||
|
|
||||||
|
mock.output_schema = MagicMock()
|
||||||
|
mock.output_schema.jsonschema.return_value = {
|
||||||
|
"properties": {"result": {"type": "object"}}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=mock,
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
RunBlockTool,
|
||||||
|
"_resolve_block_credentials",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(
|
||||||
|
{
|
||||||
|
"credentials": CredentialsMetaInput(
|
||||||
|
id="cred-id",
|
||||||
|
provider=ProviderName("test_provider"),
|
||||||
|
type="api_key",
|
||||||
|
title="Test Credential",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="api-block-id",
|
||||||
|
input_data={"credentials": {"some": "cred"}}, # Only credential
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return details because no non-credential inputs provided
|
||||||
|
assert isinstance(response, BlockDetailsResponse)
|
||||||
|
assert response.block.id == "api-block-id"
|
||||||
|
assert response.block.name == "API Call"
|
||||||
@@ -6,9 +6,9 @@ from typing import Any
|
|||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.library import model as library_model
|
from backend.api.features.library import model as library_model
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data import graph as graph_db
|
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
|
Credentials,
|
||||||
CredentialsFieldInfo,
|
CredentialsFieldInfo,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
HostScopedCredentials,
|
HostScopedCredentials,
|
||||||
@@ -44,14 +44,8 @@ async def fetch_graph_from_store_slug(
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Get the graph from store listing version
|
# Get the graph from store listing version
|
||||||
graph_meta = await store_db.get_available_graph(
|
graph = await store_db.get_available_graph(
|
||||||
store_agent.store_listing_version_id
|
store_agent.store_listing_version_id, hide_nodes=False
|
||||||
)
|
|
||||||
graph = await graph_db.get_graph(
|
|
||||||
graph_id=graph_meta.id,
|
|
||||||
version=graph_meta.version,
|
|
||||||
user_id=None, # Public access
|
|
||||||
include_subgraphs=True,
|
|
||||||
)
|
)
|
||||||
return graph, store_agent
|
return graph, store_agent
|
||||||
|
|
||||||
@@ -128,7 +122,7 @@ def build_missing_credentials_from_graph(
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
field_key: _serialize_missing_credential(field_key, field_info)
|
field_key: _serialize_missing_credential(field_key, field_info)
|
||||||
for field_key, (field_info, _node_fields) in aggregated_fields.items()
|
for field_key, (field_info, _, _) in aggregated_fields.items()
|
||||||
if field_key not in matched_keys
|
if field_key not in matched_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,6 +224,99 @@ async def get_or_create_library_agent(
|
|||||||
return library_agents[0]
|
return library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def match_credentials_to_requirements(
|
||||||
|
user_id: str,
|
||||||
|
requirements: dict[str, CredentialsFieldInfo],
|
||||||
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
|
"""
|
||||||
|
Match user's credentials against a dictionary of credential requirements.
|
||||||
|
|
||||||
|
This is the core matching logic shared by both graph and block credential matching.
|
||||||
|
"""
|
||||||
|
matched: dict[str, CredentialsMetaInput] = {}
|
||||||
|
missing: list[CredentialsMetaInput] = []
|
||||||
|
|
||||||
|
if not requirements:
|
||||||
|
return matched, missing
|
||||||
|
|
||||||
|
available_creds = await get_user_credentials(user_id)
|
||||||
|
|
||||||
|
for field_name, field_info in requirements.items():
|
||||||
|
matching_cred = find_matching_credential(available_creds, field_info)
|
||||||
|
|
||||||
|
if matching_cred:
|
||||||
|
try:
|
||||||
|
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
||||||
|
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
||||||
|
f"credential_id={matching_cred.id}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
provider = next(iter(field_info.provider), "unknown")
|
||||||
|
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||||
|
missing.append(
|
||||||
|
CredentialsMetaInput(
|
||||||
|
id=field_name,
|
||||||
|
provider=provider, # type: ignore
|
||||||
|
type=cred_type, # type: ignore
|
||||||
|
title=f"{field_name} (validation failed: {e})",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
provider = next(iter(field_info.provider), "unknown")
|
||||||
|
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||||
|
missing.append(
|
||||||
|
CredentialsMetaInput(
|
||||||
|
id=field_name,
|
||||||
|
provider=provider, # type: ignore
|
||||||
|
type=cred_type, # type: ignore
|
||||||
|
title=field_name.replace("_", " ").title(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return matched, missing
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_credentials(user_id: str) -> list[Credentials]:
|
||||||
|
"""Get all available credentials for a user."""
|
||||||
|
creds_manager = IntegrationCredentialsManager()
|
||||||
|
return await creds_manager.store.get_all_creds(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def find_matching_credential(
|
||||||
|
available_creds: list[Credentials],
|
||||||
|
field_info: CredentialsFieldInfo,
|
||||||
|
) -> Credentials | None:
|
||||||
|
"""Find a credential that matches the required provider, type, scopes, and host."""
|
||||||
|
for cred in available_creds:
|
||||||
|
if cred.provider not in field_info.provider:
|
||||||
|
continue
|
||||||
|
if cred.type not in field_info.supported_types:
|
||||||
|
continue
|
||||||
|
if cred.type == "oauth2" and not _credential_has_required_scopes(
|
||||||
|
cred, field_info
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
|
||||||
|
continue
|
||||||
|
return cred
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_credential_meta_from_match(
|
||||||
|
matching_cred: Credentials,
|
||||||
|
) -> CredentialsMetaInput:
|
||||||
|
"""Create a CredentialsMetaInput from a matched credential."""
|
||||||
|
return CredentialsMetaInput(
|
||||||
|
id=matching_cred.id,
|
||||||
|
provider=matching_cred.provider, # type: ignore
|
||||||
|
type=matching_cred.type,
|
||||||
|
title=matching_cred.title,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def match_user_credentials_to_graph(
|
async def match_user_credentials_to_graph(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
graph: GraphModel,
|
graph: GraphModel,
|
||||||
@@ -269,7 +356,8 @@ async def match_user_credentials_to_graph(
|
|||||||
# provider is in the set of acceptable providers.
|
# provider is in the set of acceptable providers.
|
||||||
for credential_field_name, (
|
for credential_field_name, (
|
||||||
credential_requirements,
|
credential_requirements,
|
||||||
_node_fields,
|
_,
|
||||||
|
_,
|
||||||
) in aggregated_creds.items():
|
) in aggregated_creds.items():
|
||||||
# Find first matching credential by provider, type, and scopes
|
# Find first matching credential by provider, type, and scopes
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
@@ -337,8 +425,6 @@ def _credential_has_required_scopes(
|
|||||||
# If no scopes are required, any credential matches
|
# If no scopes are required, any credential matches
|
||||||
if not requirements.required_scopes:
|
if not requirements.required_scopes:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check that credential scopes are a superset of required scopes
|
|
||||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,151 @@
|
|||||||
|
"""Web fetch tool — safely retrieve public web page content."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import html2text
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
|
from backend.api.features.chat.tools.models import (
|
||||||
|
ErrorResponse,
|
||||||
|
ToolResponseBase,
|
||||||
|
WebFetchResponse,
|
||||||
|
)
|
||||||
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Limits
|
||||||
|
_MAX_CONTENT_BYTES = 102_400 # 100 KB download cap
|
||||||
|
_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15)
|
||||||
|
|
||||||
|
# Content types we'll read as text
|
||||||
|
_TEXT_CONTENT_TYPES = {
|
||||||
|
"text/html",
|
||||||
|
"text/plain",
|
||||||
|
"text/xml",
|
||||||
|
"text/csv",
|
||||||
|
"text/markdown",
|
||||||
|
"application/json",
|
||||||
|
"application/xml",
|
||||||
|
"application/xhtml+xml",
|
||||||
|
"application/rss+xml",
|
||||||
|
"application/atom+xml",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_text_content(content_type: str) -> bool:
|
||||||
|
base = content_type.split(";")[0].strip().lower()
|
||||||
|
return base in _TEXT_CONTENT_TYPES or base.startswith("text/")
|
||||||
|
|
||||||
|
|
||||||
|
def _html_to_text(html: str) -> str:
|
||||||
|
h = html2text.HTML2Text()
|
||||||
|
h.ignore_links = False
|
||||||
|
h.ignore_images = True
|
||||||
|
h.body_width = 0
|
||||||
|
return h.handle(html)
|
||||||
|
|
||||||
|
|
||||||
|
class WebFetchTool(BaseTool):
|
||||||
|
"""Safely fetch content from a public URL using SSRF-protected HTTP."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "web_fetch"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Fetch the content of a public web page by URL. "
|
||||||
|
"Returns readable text extracted from HTML by default. "
|
||||||
|
"Useful for reading documentation, articles, and API responses. "
|
||||||
|
"Only supports HTTP/HTTPS GET requests to public URLs "
|
||||||
|
"(private/internal network addresses are blocked)."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The public HTTP/HTTPS URL to fetch.",
|
||||||
|
},
|
||||||
|
"extract_text": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": (
|
||||||
|
"If true (default), extract readable text from HTML. "
|
||||||
|
"If false, return raw content."
|
||||||
|
),
|
||||||
|
"default": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["url"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
url: str = (kwargs.get("url") or "").strip()
|
||||||
|
extract_text: bool = kwargs.get("extract_text", True)
|
||||||
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
|
if not url:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide a URL to fetch.",
|
||||||
|
error="missing_url",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = Requests(raise_for_status=False, retry_max_attempts=1)
|
||||||
|
response = await client.get(url, timeout=_REQUEST_TIMEOUT)
|
||||||
|
except ValueError as e:
|
||||||
|
# validate_url raises ValueError for SSRF / blocked IPs
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"URL blocked: {e}",
|
||||||
|
error="url_blocked",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[web_fetch] Request failed for {url}: {e}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to fetch URL: {e}",
|
||||||
|
error="fetch_failed",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
content_type = response.headers.get("content-type", "")
|
||||||
|
if not _is_text_content(content_type):
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Non-text content type: {content_type.split(';')[0]}",
|
||||||
|
error="unsupported_content_type",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = response.content[:_MAX_CONTENT_BYTES]
|
||||||
|
text = raw.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
if extract_text and "html" in content_type.lower():
|
||||||
|
text = _html_to_text(text)
|
||||||
|
|
||||||
|
return WebFetchResponse(
|
||||||
|
message=f"Fetched {url}",
|
||||||
|
url=response.url,
|
||||||
|
status_code=response.status,
|
||||||
|
content_type=content_type.split(";")[0].strip(),
|
||||||
|
content=text,
|
||||||
|
truncated=False,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
@@ -88,7 +88,9 @@ class ListWorkspaceFilesTool(BaseTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"List files in the user's workspace. "
|
"List files in the user's persistent workspace (cloud storage). "
|
||||||
|
"These files survive across sessions. "
|
||||||
|
"For ephemeral session files, use the SDK Read/Glob tools instead. "
|
||||||
"Returns file names, paths, sizes, and metadata. "
|
"Returns file names, paths, sizes, and metadata. "
|
||||||
"Optionally filter by path prefix."
|
"Optionally filter by path prefix."
|
||||||
)
|
)
|
||||||
@@ -204,7 +206,9 @@ class ReadWorkspaceFileTool(BaseTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Read a file from the user's workspace. "
|
"Read a file from the user's persistent workspace (cloud storage). "
|
||||||
|
"These files survive across sessions. "
|
||||||
|
"For ephemeral session files, use the SDK Read tool instead. "
|
||||||
"Specify either file_id or path to identify the file. "
|
"Specify either file_id or path to identify the file. "
|
||||||
"For small text files, returns content directly. "
|
"For small text files, returns content directly. "
|
||||||
"For large or binary files, returns metadata and a download URL. "
|
"For large or binary files, returns metadata and a download URL. "
|
||||||
@@ -378,7 +382,9 @@ class WriteWorkspaceFileTool(BaseTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Write or create a file in the user's workspace. "
|
"Write or create a file in the user's persistent workspace (cloud storage). "
|
||||||
|
"These files survive across sessions. "
|
||||||
|
"For ephemeral session files, use the SDK Write tool instead. "
|
||||||
"Provide the content as a base64-encoded string. "
|
"Provide the content as a base64-encoded string. "
|
||||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||||
"Files are saved to the current session's folder by default. "
|
"Files are saved to the current session's folder by default. "
|
||||||
@@ -523,7 +529,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Delete a file from the user's workspace. "
|
"Delete a file from the user's persistent workspace (cloud storage). "
|
||||||
"Specify either file_id or path to identify the file. "
|
"Specify either file_id or path to identify the file. "
|
||||||
"Paths are scoped to the current session by default. "
|
"Paths are scoped to the current session by default. "
|
||||||
"Use /sessions/<session_id>/... for cross-session access."
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
|||||||
@@ -12,12 +12,11 @@ import backend.api.features.store.image_gen as store_image_gen
|
|||||||
import backend.api.features.store.media as store_media
|
import backend.api.features.store.media as store_media
|
||||||
import backend.data.graph as graph_db
|
import backend.data.graph as graph_db
|
||||||
import backend.data.integrations as integrations_db
|
import backend.data.integrations as integrations_db
|
||||||
from backend.data.block import BlockInput
|
|
||||||
from backend.data.db import transaction
|
from backend.data.db import transaction
|
||||||
from backend.data.execution import get_graph_execution
|
from backend.data.execution import get_graph_execution
|
||||||
from backend.data.graph import GraphSettings
|
from backend.data.graph import GraphSettings
|
||||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput, GraphInput
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||||
on_graph_activate,
|
on_graph_activate,
|
||||||
@@ -374,7 +373,7 @@ async def get_library_agent_by_graph_id(
|
|||||||
|
|
||||||
|
|
||||||
async def add_generated_agent_image(
|
async def add_generated_agent_image(
|
||||||
graph: graph_db.BaseGraph,
|
graph: graph_db.GraphBaseMeta,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
) -> Optional[prisma.models.LibraryAgent]:
|
) -> Optional[prisma.models.LibraryAgent]:
|
||||||
@@ -1130,7 +1129,7 @@ async def create_preset_from_graph_execution(
|
|||||||
async def update_preset(
|
async def update_preset(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
preset_id: str,
|
preset_id: str,
|
||||||
inputs: Optional[BlockInput] = None,
|
inputs: Optional[GraphInput] = None,
|
||||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
|
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
description: Optional[str] = None,
|
description: Optional[str] = None,
|
||||||
|
|||||||
@@ -6,9 +6,12 @@ import prisma.enums
|
|||||||
import prisma.models
|
import prisma.models
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
from backend.data.block import BlockInput
|
|
||||||
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
|
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
|
||||||
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
from backend.data.model import (
|
||||||
|
CredentialsMetaInput,
|
||||||
|
GraphInput,
|
||||||
|
is_credentials_field_name,
|
||||||
|
)
|
||||||
from backend.util.json import loads as json_loads
|
from backend.util.json import loads as json_loads
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
@@ -323,7 +326,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
|
|||||||
graph_id: str
|
graph_id: str
|
||||||
graph_version: int
|
graph_version: int
|
||||||
|
|
||||||
inputs: BlockInput
|
inputs: GraphInput
|
||||||
credentials: dict[str, CredentialsMetaInput]
|
credentials: dict[str, CredentialsMetaInput]
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
@@ -352,7 +355,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
|
|||||||
Request model used when updating a preset for a library agent.
|
Request model used when updating a preset for a library agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
inputs: Optional[BlockInput] = None
|
inputs: Optional[GraphInput] = None
|
||||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None
|
credentials: Optional[dict[str, CredentialsMetaInput]] = None
|
||||||
|
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
@@ -395,7 +398,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
|||||||
"Webhook must be included in AgentPreset query when webhookId is set"
|
"Webhook must be included in AgentPreset query when webhookId is set"
|
||||||
)
|
)
|
||||||
|
|
||||||
input_data: BlockInput = {}
|
input_data: GraphInput = {}
|
||||||
input_credentials: dict[str, CredentialsMetaInput] = {}
|
input_credentials: dict[str, CredentialsMetaInput] = {}
|
||||||
|
|
||||||
for preset_input in preset.InputPresets:
|
for preset_input in preset.InputPresets:
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from typing import Optional
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from backend.blocks import get_block
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data.block import get_block
|
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
from .models import ApiResponse, ChatRequest, GraphData
|
from .models import ApiResponse, ChatRequest, GraphData
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ class BlockHandler(ContentHandler):
|
|||||||
|
|
||||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||||
"""Fetch blocks without embeddings."""
|
"""Fetch blocks without embeddings."""
|
||||||
from backend.data.block import get_blocks
|
from backend.blocks import get_blocks
|
||||||
|
|
||||||
# Get all available blocks
|
# Get all available blocks
|
||||||
all_blocks = get_blocks()
|
all_blocks = get_blocks()
|
||||||
@@ -249,7 +249,7 @@ class BlockHandler(ContentHandler):
|
|||||||
|
|
||||||
async def get_stats(self) -> dict[str, int]:
|
async def get_stats(self) -> dict[str, int]:
|
||||||
"""Get statistics about block embedding coverage."""
|
"""Get statistics about block embedding coverage."""
|
||||||
from backend.data.block import get_blocks
|
from backend.blocks import get_blocks
|
||||||
|
|
||||||
all_blocks = get_blocks()
|
all_blocks = get_blocks()
|
||||||
|
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
|
|||||||
mock_existing = []
|
mock_existing = []
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.data.block.get_blocks",
|
"backend.blocks.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
@@ -135,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
|
|||||||
mock_embedded = [{"count": 2}]
|
mock_embedded = [{"count": 2}]
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.data.block.get_blocks",
|
"backend.blocks.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
@@ -327,7 +327,7 @@ async def test_block_handler_handles_missing_attributes():
|
|||||||
mock_blocks = {"block-minimal": mock_block_class}
|
mock_blocks = {"block-minimal": mock_block_class}
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.data.block.get_blocks",
|
"backend.blocks.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
@@ -360,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
|
|||||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.data.block.get_blocks",
|
"backend.blocks.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, overload
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
@@ -11,8 +11,8 @@ import prisma.types
|
|||||||
|
|
||||||
from backend.data.db import transaction
|
from backend.data.db import transaction
|
||||||
from backend.data.graph import (
|
from backend.data.graph import (
|
||||||
GraphMeta,
|
|
||||||
GraphModel,
|
GraphModel,
|
||||||
|
GraphModelWithoutNodes,
|
||||||
get_graph,
|
get_graph,
|
||||||
get_graph_as_admin,
|
get_graph_as_admin,
|
||||||
get_sub_graphs,
|
get_sub_graphs,
|
||||||
@@ -334,7 +334,22 @@ async def get_store_agent_details(
|
|||||||
raise DatabaseError("Failed to fetch agent details") from e
|
raise DatabaseError("Failed to fetch agent details") from e
|
||||||
|
|
||||||
|
|
||||||
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
@overload
|
||||||
|
async def get_available_graph(
|
||||||
|
store_listing_version_id: str, hide_nodes: Literal[False]
|
||||||
|
) -> GraphModel: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def get_available_graph(
|
||||||
|
store_listing_version_id: str, hide_nodes: Literal[True] = True
|
||||||
|
) -> GraphModelWithoutNodes: ...
|
||||||
|
|
||||||
|
|
||||||
|
async def get_available_graph(
|
||||||
|
store_listing_version_id: str,
|
||||||
|
hide_nodes: bool = True,
|
||||||
|
) -> GraphModelWithoutNodes | GraphModel:
|
||||||
try:
|
try:
|
||||||
# Get avaialble, non-deleted store listing version
|
# Get avaialble, non-deleted store listing version
|
||||||
store_listing_version = (
|
store_listing_version = (
|
||||||
@@ -344,7 +359,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
|||||||
"isAvailable": True,
|
"isAvailable": True,
|
||||||
"isDeleted": False,
|
"isDeleted": False,
|
||||||
},
|
},
|
||||||
include={"AgentGraph": {"include": {"Nodes": True}}},
|
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -354,7 +369,9 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
|||||||
detail=f"Store listing version {store_listing_version_id} not found",
|
detail=f"Store listing version {store_listing_version_id} not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
|
return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
|
||||||
|
store_listing_version.AgentGraph
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting agent: {e}")
|
logger.error(f"Error getting agent: {e}")
|
||||||
|
|||||||
@@ -662,7 +662,7 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
|||||||
)
|
)
|
||||||
current_ids = {row["id"] for row in valid_agents}
|
current_ids = {row["id"] for row in valid_agents}
|
||||||
elif content_type == ContentType.BLOCK:
|
elif content_type == ContentType.BLOCK:
|
||||||
from backend.data.block import get_blocks
|
from backend.blocks import get_blocks
|
||||||
|
|
||||||
current_ids = set(get_blocks().keys())
|
current_ids = set(get_blocks().keys())
|
||||||
elif content_type == ContentType.DOCUMENTATION:
|
elif content_type == ContentType.DOCUMENTATION:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ Includes BM25 reranking for improved lexical relevance.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@@ -362,7 +363,11 @@ async def unified_hybrid_search(
|
|||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await query_raw_with_schema(sql_query, *params)
|
try:
|
||||||
|
results = await query_raw_with_schema(sql_query, *params)
|
||||||
|
except Exception as e:
|
||||||
|
await _log_vector_error_diagnostics(e)
|
||||||
|
raise
|
||||||
|
|
||||||
total = results[0]["total_count"] if results else 0
|
total = results[0]["total_count"] if results else 0
|
||||||
# Apply BM25 reranking
|
# Apply BM25 reranking
|
||||||
@@ -686,7 +691,11 @@ async def hybrid_search(
|
|||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await query_raw_with_schema(sql_query, *params)
|
try:
|
||||||
|
results = await query_raw_with_schema(sql_query, *params)
|
||||||
|
except Exception as e:
|
||||||
|
await _log_vector_error_diagnostics(e)
|
||||||
|
raise
|
||||||
|
|
||||||
total = results[0]["total_count"] if results else 0
|
total = results[0]["total_count"] if results else 0
|
||||||
|
|
||||||
@@ -718,6 +727,87 @@ async def hybrid_search_simple(
|
|||||||
return await hybrid_search(query=query, page=page, page_size=page_size)
|
return await hybrid_search(query=query, page=page, page_size=page_size)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Diagnostics
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Rate limit: only log vector error diagnostics once per this interval
|
||||||
|
_VECTOR_DIAG_INTERVAL_SECONDS = 60
|
||||||
|
_last_vector_diag_time: float = 0
|
||||||
|
|
||||||
|
|
||||||
|
async def _log_vector_error_diagnostics(error: Exception) -> None:
|
||||||
|
"""Log diagnostic info when 'type vector does not exist' error occurs.
|
||||||
|
|
||||||
|
Note: Diagnostic queries use query_raw_with_schema which may run on a different
|
||||||
|
pooled connection than the one that failed. Session-level search_path can differ,
|
||||||
|
so these diagnostics show cluster-wide state, not necessarily the failed session.
|
||||||
|
|
||||||
|
Includes rate limiting to avoid log spam - only logs once per minute.
|
||||||
|
Caller should re-raise the error after calling this function.
|
||||||
|
"""
|
||||||
|
global _last_vector_diag_time
|
||||||
|
|
||||||
|
# Check if this is the vector type error
|
||||||
|
error_str = str(error).lower()
|
||||||
|
if not (
|
||||||
|
"type" in error_str and "vector" in error_str and "does not exist" in error_str
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Rate limit: only log once per interval
|
||||||
|
now = time.time()
|
||||||
|
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
|
||||||
|
return
|
||||||
|
_last_vector_diag_time = now
|
||||||
|
|
||||||
|
try:
|
||||||
|
diagnostics: dict[str, object] = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
search_path_result = await query_raw_with_schema("SHOW search_path")
|
||||||
|
diagnostics["search_path"] = search_path_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["search_path"] = f"Error: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
schema_result = await query_raw_with_schema("SELECT current_schema()")
|
||||||
|
diagnostics["current_schema"] = schema_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["current_schema"] = f"Error: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_result = await query_raw_with_schema(
|
||||||
|
"SELECT current_user, session_user, current_database()"
|
||||||
|
)
|
||||||
|
diagnostics["user_info"] = user_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["user_info"] = f"Error: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check pgvector extension installation (cluster-wide, stable info)
|
||||||
|
ext_result = await query_raw_with_schema(
|
||||||
|
"SELECT extname, extversion, nspname as schema "
|
||||||
|
"FROM pg_extension e "
|
||||||
|
"JOIN pg_namespace n ON e.extnamespace = n.oid "
|
||||||
|
"WHERE extname = 'vector'"
|
||||||
|
)
|
||||||
|
diagnostics["pgvector_extension"] = ext_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["pgvector_extension"] = f"Error: {e}"
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
f"Vector type error diagnostics:\n"
|
||||||
|
f" Error: {error}\n"
|
||||||
|
f" search_path: {diagnostics.get('search_path')}\n"
|
||||||
|
f" current_schema: {diagnostics.get('current_schema')}\n"
|
||||||
|
f" user_info: {diagnostics.get('user_info')}\n"
|
||||||
|
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
|
||||||
|
)
|
||||||
|
except Exception as diag_error:
|
||||||
|
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
||||||
# for existing code that expects the popularity parameter
|
# for existing code that expects the popularity parameter
|
||||||
HybridSearchWeights = StoreAgentSearchWeights
|
HybridSearchWeights = StoreAgentSearchWeights
|
||||||
|
|||||||
@@ -7,16 +7,7 @@ from replicate.client import Client as ReplicateClient
|
|||||||
from replicate.exceptions import ReplicateError
|
from replicate.exceptions import ReplicateError
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.blocks.ideogram import (
|
from backend.data.graph import GraphBaseMeta
|
||||||
AspectRatio,
|
|
||||||
ColorPalettePreset,
|
|
||||||
IdeogramModelBlock,
|
|
||||||
IdeogramModelName,
|
|
||||||
MagicPromptOption,
|
|
||||||
StyleType,
|
|
||||||
UpscaleOption,
|
|
||||||
)
|
|
||||||
from backend.data.graph import BaseGraph
|
|
||||||
from backend.data.model import CredentialsMetaInput, ProviderName
|
from backend.data.model import CredentialsMetaInput, ProviderName
|
||||||
from backend.integrations.credentials_store import ideogram_credentials
|
from backend.integrations.credentials_store import ideogram_credentials
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
@@ -34,14 +25,14 @@ class ImageStyle(str, Enum):
|
|||||||
DIGITAL_ART = "digital art"
|
DIGITAL_ART = "digital art"
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
||||||
if settings.config.use_agent_image_generation_v2:
|
if settings.config.use_agent_image_generation_v2:
|
||||||
return await generate_agent_image_v2(graph=agent)
|
return await generate_agent_image_v2(graph=agent)
|
||||||
else:
|
else:
|
||||||
return await generate_agent_image_v1(agent=agent)
|
return await generate_agent_image_v1(agent=agent)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Generate an image for an agent using Ideogram model.
|
Generate an image for an agent using Ideogram model.
|
||||||
Returns:
|
Returns:
|
||||||
@@ -50,18 +41,31 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
|||||||
if not ideogram_credentials.api_key:
|
if not ideogram_credentials.api_key:
|
||||||
raise ValueError("Missing Ideogram API key")
|
raise ValueError("Missing Ideogram API key")
|
||||||
|
|
||||||
|
from backend.blocks.ideogram import (
|
||||||
|
AspectRatio,
|
||||||
|
ColorPalettePreset,
|
||||||
|
IdeogramModelBlock,
|
||||||
|
IdeogramModelName,
|
||||||
|
MagicPromptOption,
|
||||||
|
StyleType,
|
||||||
|
UpscaleOption,
|
||||||
|
)
|
||||||
|
|
||||||
name = graph.name
|
name = graph.name
|
||||||
description = f"{name} ({graph.description})" if graph.description else name
|
description = f"{name} ({graph.description})" if graph.description else name
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
|
"Create a visually striking retro-futuristic vector pop art illustration "
|
||||||
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
|
f'prominently featuring "{name}" in bold typography. The image clearly and '
|
||||||
f"along with recognizable objects directly associated with the primary function of a {name}. "
|
f"literally depicts a {description}, along with recognizable objects directly "
|
||||||
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
|
f"associated with the primary function of a {name}. "
|
||||||
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
|
f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
|
||||||
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
|
f"clearly conveying the purpose of a {name}. "
|
||||||
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
|
"Maintain vibrant, limited-palette colors, sharp vector lines, "
|
||||||
f"prioritizing clear visual storytelling and thematic clarity above all else."
|
"geometric shapes, flat illustration techniques, and solid colors "
|
||||||
|
"without gradients or shading. Preserve a retro-futuristic aesthetic "
|
||||||
|
"influenced by mid-century futurism and 1960s psychedelia, "
|
||||||
|
"prioritizing clear visual storytelling and thematic clarity above all else."
|
||||||
)
|
)
|
||||||
|
|
||||||
custom_colors = [
|
custom_colors = [
|
||||||
@@ -99,12 +103,12 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
|||||||
return io.BytesIO(response.content)
|
return io.BytesIO(response.content)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Generate an image for an agent using Flux model via Replicate API.
|
Generate an image for an agent using Flux model via Replicate API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent (Graph): The agent to generate an image for
|
agent (GraphBaseMeta | AgentGraph): The agent to generate an image for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
io.BytesIO: The generated image as bytes
|
io.BytesIO: The generated image as bytes
|
||||||
@@ -114,7 +118,13 @@ async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
|||||||
raise ValueError("Missing Replicate API key in settings")
|
raise ValueError("Missing Replicate API key in settings")
|
||||||
|
|
||||||
# Construct prompt from agent details
|
# Construct prompt from agent details
|
||||||
prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
|
prompt = (
|
||||||
|
"Create a visually engaging app store thumbnail for the AI agent "
|
||||||
|
"that highlights what it does in a clear and captivating way:\n"
|
||||||
|
f"- **Name**: {agent.name}\n"
|
||||||
|
f"- **Description**: {agent.description}\n"
|
||||||
|
f"Focus on showcasing its core functionality with an appealing design."
|
||||||
|
)
|
||||||
|
|
||||||
# Set up Replicate client
|
# Set up Replicate client
|
||||||
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ async def get_agent(
|
|||||||
)
|
)
|
||||||
async def get_graph_meta_by_store_listing_version_id(
|
async def get_graph_meta_by_store_listing_version_id(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
) -> backend.data.graph.GraphMeta:
|
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||||
"""
|
"""
|
||||||
Get Agent Graph from Store Listing Version ID.
|
Get Agent Graph from Store Listing Version ID.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -40,10 +40,11 @@ from backend.api.model import (
|
|||||||
UpdateTimezoneRequest,
|
UpdateTimezoneRequest,
|
||||||
UploadFileResponse,
|
UploadFileResponse,
|
||||||
)
|
)
|
||||||
|
from backend.blocks import get_block, get_blocks
|
||||||
from backend.data import execution as execution_db
|
from backend.data import execution as execution_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data.auth import api_key as api_key_db
|
from backend.data.auth import api_key as api_key_db
|
||||||
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
|
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||||
from backend.data.credit import (
|
from backend.data.credit import (
|
||||||
AutoTopUpConfig,
|
AutoTopUpConfig,
|
||||||
RefundRequest,
|
RefundRequest,
|
||||||
|
|||||||
@@ -3,22 +3,19 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, TypeVar
|
from typing import Sequence, Type, TypeVar
|
||||||
|
|
||||||
|
from backend.blocks._base import AnyBlockSchema, BlockType
|
||||||
from backend.util.cache import cached
|
from backend.util.cache import cached
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from backend.data.block import Block
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl_seconds=3600)
|
@cached(ttl_seconds=3600)
|
||||||
def load_all_blocks() -> dict[str, type["Block"]]:
|
def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
|
||||||
from backend.data.block import Block
|
from backend.blocks._base import Block
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
|
|
||||||
# Check if example blocks should be loaded from settings
|
# Check if example blocks should be loaded from settings
|
||||||
@@ -50,8 +47,8 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
|||||||
importlib.import_module(f".{module}", package=__name__)
|
importlib.import_module(f".{module}", package=__name__)
|
||||||
|
|
||||||
# Load all Block instances from the available modules
|
# Load all Block instances from the available modules
|
||||||
available_blocks: dict[str, type["Block"]] = {}
|
available_blocks: dict[str, type["AnyBlockSchema"]] = {}
|
||||||
for block_cls in all_subclasses(Block):
|
for block_cls in _all_subclasses(Block):
|
||||||
class_name = block_cls.__name__
|
class_name = block_cls.__name__
|
||||||
|
|
||||||
if class_name.endswith("Base"):
|
if class_name.endswith("Base"):
|
||||||
@@ -64,7 +61,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
|||||||
"please name the class with 'Base' at the end"
|
"please name the class with 'Base' at the end"
|
||||||
)
|
)
|
||||||
|
|
||||||
block = block_cls.create()
|
block = block_cls() # pyright: ignore[reportAbstractUsage]
|
||||||
|
|
||||||
if not isinstance(block.id, str) or len(block.id) != 36:
|
if not isinstance(block.id, str) or len(block.id) != 36:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -105,7 +102,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
|||||||
available_blocks[block.id] = block_cls
|
available_blocks[block.id] = block_cls
|
||||||
|
|
||||||
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
|
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
|
||||||
from backend.data.block import is_block_auth_configured
|
from ._utils import is_block_auth_configured
|
||||||
|
|
||||||
filtered_blocks = {}
|
filtered_blocks = {}
|
||||||
for block_id, block_cls in available_blocks.items():
|
for block_id, block_cls in available_blocks.items():
|
||||||
@@ -115,11 +112,48 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
|||||||
return filtered_blocks
|
return filtered_blocks
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["load_all_blocks"]
|
def _all_subclasses(cls: type[T]) -> list[type[T]]:
|
||||||
|
|
||||||
|
|
||||||
def all_subclasses(cls: type[T]) -> list[type[T]]:
|
|
||||||
subclasses = cls.__subclasses__()
|
subclasses = cls.__subclasses__()
|
||||||
for subclass in subclasses:
|
for subclass in subclasses:
|
||||||
subclasses += all_subclasses(subclass)
|
subclasses += _all_subclasses(subclass)
|
||||||
return subclasses
|
return subclasses
|
||||||
|
|
||||||
|
|
||||||
|
# ============== Block access helper functions ============== #
|
||||||
|
|
||||||
|
|
||||||
|
def get_blocks() -> dict[str, Type["AnyBlockSchema"]]:
|
||||||
|
return load_all_blocks()
|
||||||
|
|
||||||
|
|
||||||
|
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
||||||
|
def get_block(block_id: str) -> "AnyBlockSchema | None":
|
||||||
|
cls = get_blocks().get(block_id)
|
||||||
|
return cls() if cls else None
|
||||||
|
|
||||||
|
|
||||||
|
@cached(ttl_seconds=3600)
|
||||||
|
def get_webhook_block_ids() -> Sequence[str]:
|
||||||
|
return [
|
||||||
|
id
|
||||||
|
for id, B in get_blocks().items()
|
||||||
|
if B().block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@cached(ttl_seconds=3600)
|
||||||
|
def get_io_block_ids() -> Sequence[str]:
|
||||||
|
return [
|
||||||
|
id
|
||||||
|
for id, B in get_blocks().items()
|
||||||
|
if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@cached(ttl_seconds=3600)
|
||||||
|
def get_human_in_the_loop_block_ids() -> Sequence[str]:
|
||||||
|
return [
|
||||||
|
id
|
||||||
|
for id, B in get_blocks().items()
|
||||||
|
if B().block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||||
|
]
|
||||||
|
|||||||
739
autogpt_platform/backend/backend/blocks/_base.py
Normal file
739
autogpt_platform/backend/backend/blocks/_base.py
Normal file
@@ -0,0 +1,739 @@
|
|||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
ClassVar,
|
||||||
|
Generic,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeAlias,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
get_origin,
|
||||||
|
)
|
||||||
|
|
||||||
|
import jsonref
|
||||||
|
import jsonschema
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
|
||||||
|
from backend.data.model import (
|
||||||
|
Credentials,
|
||||||
|
CredentialsFieldInfo,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
SchemaField,
|
||||||
|
is_credentials_field_name,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util import json
|
||||||
|
from backend.util.exceptions import (
|
||||||
|
BlockError,
|
||||||
|
BlockExecutionError,
|
||||||
|
BlockInputError,
|
||||||
|
BlockOutputError,
|
||||||
|
BlockUnknownError,
|
||||||
|
)
|
||||||
|
from backend.util.settings import Config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import ContributorDetails, NodeExecutionStats
|
||||||
|
|
||||||
|
from ..data.graph import Link
|
||||||
|
|
||||||
|
app_config = Config()
|
||||||
|
|
||||||
|
|
||||||
|
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
|
||||||
|
|
||||||
|
|
||||||
|
class BlockType(Enum):
|
||||||
|
STANDARD = "Standard"
|
||||||
|
INPUT = "Input"
|
||||||
|
OUTPUT = "Output"
|
||||||
|
NOTE = "Note"
|
||||||
|
WEBHOOK = "Webhook"
|
||||||
|
WEBHOOK_MANUAL = "Webhook (manual)"
|
||||||
|
AGENT = "Agent"
|
||||||
|
AI = "AI"
|
||||||
|
AYRSHARE = "Ayrshare"
|
||||||
|
HUMAN_IN_THE_LOOP = "Human In The Loop"
|
||||||
|
|
||||||
|
|
||||||
|
class BlockCategory(Enum):
|
||||||
|
AI = "Block that leverages AI to perform a task."
|
||||||
|
SOCIAL = "Block that interacts with social media platforms."
|
||||||
|
TEXT = "Block that processes text data."
|
||||||
|
SEARCH = "Block that searches or extracts information from the internet."
|
||||||
|
BASIC = "Block that performs basic operations."
|
||||||
|
INPUT = "Block that interacts with input of the graph."
|
||||||
|
OUTPUT = "Block that interacts with output of the graph."
|
||||||
|
LOGIC = "Programming logic to control the flow of your agent"
|
||||||
|
COMMUNICATION = "Block that interacts with communication platforms."
|
||||||
|
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
|
||||||
|
DATA = "Block that interacts with structured data."
|
||||||
|
HARDWARE = "Block that interacts with hardware."
|
||||||
|
AGENT = "Block that interacts with other agents."
|
||||||
|
CRM = "Block that interacts with CRM services."
|
||||||
|
SAFETY = (
|
||||||
|
"Block that provides AI safety mechanisms such as detecting harmful content"
|
||||||
|
)
|
||||||
|
PRODUCTIVITY = "Block that helps with productivity"
|
||||||
|
ISSUE_TRACKING = "Block that helps with issue tracking"
|
||||||
|
MULTIMEDIA = "Block that interacts with multimedia content"
|
||||||
|
MARKETING = "Block that helps with marketing"
|
||||||
|
|
||||||
|
def dict(self) -> dict[str, str]:
|
||||||
|
return {"category": self.name, "description": self.value}
|
||||||
|
|
||||||
|
|
||||||
|
class BlockCostType(str, Enum):
|
||||||
|
RUN = "run" # cost X credits per run
|
||||||
|
BYTE = "byte" # cost X credits per byte
|
||||||
|
SECOND = "second" # cost X credits per second
|
||||||
|
|
||||||
|
|
||||||
|
class BlockCost(BaseModel):
|
||||||
|
cost_amount: int
|
||||||
|
cost_filter: BlockInput
|
||||||
|
cost_type: BlockCostType
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cost_amount: int,
|
||||||
|
cost_type: BlockCostType = BlockCostType.RUN,
|
||||||
|
cost_filter: Optional[BlockInput] = None,
|
||||||
|
**data: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
cost_amount=cost_amount,
|
||||||
|
cost_filter=cost_filter or {},
|
||||||
|
cost_type=cost_type,
|
||||||
|
**data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockInfo(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
inputSchema: dict[str, Any]
|
||||||
|
outputSchema: dict[str, Any]
|
||||||
|
costs: list[BlockCost]
|
||||||
|
description: str
|
||||||
|
categories: list[dict[str, str]]
|
||||||
|
contributors: list[dict[str, Any]]
|
||||||
|
staticOutput: bool
|
||||||
|
uiType: str
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSchema(BaseModel):
|
||||||
|
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def jsonschema(cls) -> dict[str, Any]:
|
||||||
|
if cls.cached_jsonschema:
|
||||||
|
return cls.cached_jsonschema
|
||||||
|
|
||||||
|
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
||||||
|
|
||||||
|
def ref_to_dict(obj):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
# OpenAPI <3.1 does not support sibling fields that has a $ref key
|
||||||
|
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
|
||||||
|
keys = {"allOf", "anyOf", "oneOf"}
|
||||||
|
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
|
||||||
|
if one_key:
|
||||||
|
obj.update(obj[one_key][0])
|
||||||
|
|
||||||
|
return {
|
||||||
|
key: ref_to_dict(value)
|
||||||
|
for key, value in obj.items()
|
||||||
|
if not key.startswith("$") and key != one_key
|
||||||
|
}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [ref_to_dict(item) for item in obj]
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|
||||||
|
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||||
|
|
||||||
|
return cls.cached_jsonschema
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_data(cls, data: BlockInput) -> str | None:
|
||||||
|
return json.validate_with_jsonschema(
|
||||||
|
schema=cls.jsonschema(),
|
||||||
|
data={k: v for k, v in data.items() if v is not None},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||||
|
return cls.validate_data(data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
|
||||||
|
model_schema = cls.jsonschema().get("properties", {})
|
||||||
|
if not model_schema:
|
||||||
|
raise ValueError(f"Invalid model schema {cls}")
|
||||||
|
|
||||||
|
property_schema = model_schema.get(field_name)
|
||||||
|
if not property_schema:
|
||||||
|
raise ValueError(f"Invalid property name {field_name}")
|
||||||
|
|
||||||
|
return property_schema
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
|
||||||
|
"""
|
||||||
|
Validate the data against a specific property (one of the input/output name).
|
||||||
|
Returns the validation error message if the data does not match the schema.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
property_schema = cls.get_field_schema(field_name)
|
||||||
|
jsonschema.validate(json.to_dict(data), property_schema)
|
||||||
|
return None
|
||||||
|
except jsonschema.ValidationError as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_fields(cls) -> set[str]:
|
||||||
|
return set(cls.model_fields.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_required_fields(cls) -> set[str]:
|
||||||
|
return {
|
||||||
|
field
|
||||||
|
for field, field_info in cls.model_fields.items()
|
||||||
|
if field_info.is_required()
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __pydantic_init_subclass__(cls, **kwargs):
|
||||||
|
"""Validates the schema definition. Rules:
|
||||||
|
- Fields with annotation `CredentialsMetaInput` MUST be
|
||||||
|
named `credentials` or `*_credentials`
|
||||||
|
- Fields named `credentials` or `*_credentials` MUST be
|
||||||
|
of type `CredentialsMetaInput`
|
||||||
|
"""
|
||||||
|
super().__pydantic_init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
# Reset cached JSON schema to prevent inheriting it from parent class
|
||||||
|
cls.cached_jsonschema = {}
|
||||||
|
|
||||||
|
credentials_fields = cls.get_credentials_fields()
|
||||||
|
|
||||||
|
for field_name in cls.get_fields():
|
||||||
|
if is_credentials_field_name(field_name):
|
||||||
|
if field_name not in credentials_fields:
|
||||||
|
raise TypeError(
|
||||||
|
f"Credentials field '{field_name}' on {cls.__qualname__} "
|
||||||
|
f"is not of type {CredentialsMetaInput.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
CredentialsMetaInput.validate_credentials_field_schema(
|
||||||
|
cls.get_field_schema(field_name), field_name
|
||||||
|
)
|
||||||
|
|
||||||
|
elif field_name in credentials_fields:
|
||||||
|
raise KeyError(
|
||||||
|
f"Credentials field '{field_name}' on {cls.__qualname__} "
|
||||||
|
"has invalid name: must be 'credentials' or *_credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
|
||||||
|
return {
|
||||||
|
field_name: info.annotation
|
||||||
|
for field_name, info in cls.model_fields.items()
|
||||||
|
if (
|
||||||
|
inspect.isclass(info.annotation)
|
||||||
|
and issubclass(
|
||||||
|
get_origin(info.annotation) or info.annotation,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_auto_credentials_fields(cls) -> dict[str, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get fields that have auto_credentials metadata (e.g., GoogleDriveFileInput).
|
||||||
|
|
||||||
|
Returns a dict mapping kwarg_name -> {field_name, auto_credentials_config}
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If multiple fields have the same kwarg_name, as this would
|
||||||
|
cause silent overwriting and only the last field would be processed.
|
||||||
|
"""
|
||||||
|
result: dict[str, dict[str, Any]] = {}
|
||||||
|
schema = cls.jsonschema()
|
||||||
|
properties = schema.get("properties", {})
|
||||||
|
|
||||||
|
for field_name, field_schema in properties.items():
|
||||||
|
auto_creds = field_schema.get("auto_credentials")
|
||||||
|
if auto_creds:
|
||||||
|
kwarg_name = auto_creds.get("kwarg_name", "credentials")
|
||||||
|
if kwarg_name in result:
|
||||||
|
raise ValueError(
|
||||||
|
f"Duplicate auto_credentials kwarg_name '{kwarg_name}' "
|
||||||
|
f"in fields '{result[kwarg_name]['field_name']}' and "
|
||||||
|
f"'{field_name}' on {cls.__qualname__}"
|
||||||
|
)
|
||||||
|
result[kwarg_name] = {
|
||||||
|
"field_name": field_name,
|
||||||
|
"config": auto_creds,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# Regular credentials fields
|
||||||
|
for field_name in cls.get_credentials_fields().keys():
|
||||||
|
result[field_name] = CredentialsFieldInfo.model_validate(
|
||||||
|
cls.get_field_schema(field_name), by_alias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto-generated credentials fields (from GoogleDriveFileInput etc.)
|
||||||
|
for kwarg_name, info in cls.get_auto_credentials_fields().items():
|
||||||
|
config = info["config"]
|
||||||
|
# Build a schema-like dict that CredentialsFieldInfo can parse
|
||||||
|
auto_schema = {
|
||||||
|
"credentials_provider": [config.get("provider", "google")],
|
||||||
|
"credentials_types": [config.get("type", "oauth2")],
|
||||||
|
"credentials_scopes": config.get("scopes"),
|
||||||
|
}
|
||||||
|
result[kwarg_name] = CredentialsFieldInfo.model_validate(
|
||||||
|
auto_schema, by_alias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||||
|
return data # Return as is, by default.
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
||||||
|
input_fields_from_nodes = {link.sink_name for link in links}
|
||||||
|
return input_fields_from_nodes - set(data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||||
|
return cls.get_required_fields() - set(data)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSchemaInput(BlockSchema):
|
||||||
|
"""
|
||||||
|
Base schema class for block inputs.
|
||||||
|
All block input schemas should extend this class for consistency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSchemaOutput(BlockSchema):
|
||||||
|
"""
|
||||||
|
Base schema class for block outputs that includes a standard error field.
|
||||||
|
All block output schemas should extend this class to ensure consistent error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
error: str = SchemaField(
|
||||||
|
description="Error message if the operation failed", default=""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchemaInput)
|
||||||
|
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchemaOutput)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyInputSchema(BlockSchemaInput):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyOutputSchema(BlockSchemaOutput):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# For backward compatibility - will be deprecated
|
||||||
|
EmptySchema = EmptyOutputSchema
|
||||||
|
|
||||||
|
|
||||||
|
# --8<-- [start:BlockWebhookConfig]
|
||||||
|
class BlockManualWebhookConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration model for webhook-triggered blocks on which
|
||||||
|
the user has to manually set up the webhook at the provider.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: ProviderName
|
||||||
|
"""The service provider that the webhook connects to"""
|
||||||
|
|
||||||
|
webhook_type: str
|
||||||
|
"""
|
||||||
|
Identifier for the webhook type. E.g. GitHub has repo and organization level hooks.
|
||||||
|
|
||||||
|
Only for use in the corresponding `WebhooksManager`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_filter_input: str = ""
|
||||||
|
"""
|
||||||
|
Name of the block's event filter input.
|
||||||
|
Leave empty if the corresponding webhook doesn't have distinct event/payload types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_format: str = "{event}"
|
||||||
|
"""
|
||||||
|
Template string for the event(s) that a block instance subscribes to.
|
||||||
|
Applied individually to each event selected in the event filter input.
|
||||||
|
|
||||||
|
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||||
|
"""
|
||||||
|
Configuration model for webhook-triggered blocks for which
|
||||||
|
the webhook can be automatically set up through the provider's API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
resource_format: str
|
||||||
|
"""
|
||||||
|
Template string for the resource that a block instance subscribes to.
|
||||||
|
Fields will be filled from the block's inputs (except `payload`).
|
||||||
|
|
||||||
|
Example: `f"{repo}/pull_requests"` (note: not how it's actually implemented)
|
||||||
|
|
||||||
|
Only for use in the corresponding `WebhooksManager`.
|
||||||
|
"""
|
||||||
|
# --8<-- [end:BlockWebhookConfig]
|
||||||
|
|
||||||
|
|
||||||
|
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: str = "",
|
||||||
|
description: str = "",
|
||||||
|
contributors: list["ContributorDetails"] = [],
|
||||||
|
categories: set[BlockCategory] | None = None,
|
||||||
|
input_schema: Type[BlockSchemaInputType] = EmptyInputSchema,
|
||||||
|
output_schema: Type[BlockSchemaOutputType] = EmptyOutputSchema,
|
||||||
|
test_input: BlockInput | list[BlockInput] | None = None,
|
||||||
|
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
|
||||||
|
test_mock: dict[str, Any] | None = None,
|
||||||
|
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
|
||||||
|
disabled: bool = False,
|
||||||
|
static_output: bool = False,
|
||||||
|
block_type: BlockType = BlockType.STANDARD,
|
||||||
|
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
|
||||||
|
is_sensitive_action: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the block with the given schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id: The unique identifier for the block, this value will be persisted in the
|
||||||
|
DB. So it should be a unique and constant across the application run.
|
||||||
|
Use the UUID format for the ID.
|
||||||
|
description: The description of the block, explaining what the block does.
|
||||||
|
contributors: The list of contributors who contributed to the block.
|
||||||
|
input_schema: The schema, defined as a Pydantic model, for the input data.
|
||||||
|
output_schema: The schema, defined as a Pydantic model, for the output data.
|
||||||
|
test_input: The list or single sample input data for the block, for testing.
|
||||||
|
test_output: The list or single expected output if the test_input is run.
|
||||||
|
test_mock: function names on the block implementation to mock on test run.
|
||||||
|
disabled: If the block is disabled, it will not be available for execution.
|
||||||
|
static_output: Whether the output links of the block are static by default.
|
||||||
|
"""
|
||||||
|
from backend.data.model import NodeExecutionStats
|
||||||
|
|
||||||
|
self.id = id
|
||||||
|
self.input_schema = input_schema
|
||||||
|
self.output_schema = output_schema
|
||||||
|
self.test_input = test_input
|
||||||
|
self.test_output = test_output
|
||||||
|
self.test_mock = test_mock
|
||||||
|
self.test_credentials = test_credentials
|
||||||
|
self.description = description
|
||||||
|
self.categories = categories or set()
|
||||||
|
self.contributors = contributors or set()
|
||||||
|
self.disabled = disabled
|
||||||
|
self.static_output = static_output
|
||||||
|
self.block_type = block_type
|
||||||
|
self.webhook_config = webhook_config
|
||||||
|
self.is_sensitive_action = is_sensitive_action
|
||||||
|
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
||||||
|
|
||||||
|
if self.webhook_config:
|
||||||
|
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||||
|
# Enforce presence of credentials field on auto-setup webhook blocks
|
||||||
|
if not (cred_fields := self.input_schema.get_credentials_fields()):
|
||||||
|
raise TypeError(
|
||||||
|
"credentials field is required on auto-setup webhook blocks"
|
||||||
|
)
|
||||||
|
# Disallow multiple credentials inputs on webhook blocks
|
||||||
|
elif len(cred_fields) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Multiple credentials inputs not supported on webhook blocks"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.block_type = BlockType.WEBHOOK
|
||||||
|
else:
|
||||||
|
self.block_type = BlockType.WEBHOOK_MANUAL
|
||||||
|
|
||||||
|
# Enforce shape of webhook event filter, if present
|
||||||
|
if self.webhook_config.event_filter_input:
|
||||||
|
event_filter_field = self.input_schema.model_fields[
|
||||||
|
self.webhook_config.event_filter_input
|
||||||
|
]
|
||||||
|
if not (
|
||||||
|
isinstance(event_filter_field.annotation, type)
|
||||||
|
and issubclass(event_filter_field.annotation, BaseModel)
|
||||||
|
and all(
|
||||||
|
field.annotation is bool
|
||||||
|
for field in event_filter_field.annotation.model_fields.values()
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.name} has an invalid webhook event selector: "
|
||||||
|
"field must be a BaseModel and all its fields must be boolean"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enforce presence of 'payload' input
|
||||||
|
if "payload" not in self.input_schema.model_fields:
|
||||||
|
raise TypeError(
|
||||||
|
f"{self.name} is webhook-triggered but has no 'payload' input"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Disable webhook-triggered block if webhook functionality not available
|
||||||
|
if not app_config.platform_base_url:
|
||||||
|
self.disabled = True
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
|
||||||
|
"""
|
||||||
|
Run the block with the given input data.
|
||||||
|
Args:
|
||||||
|
input_data: The input data with the structure of input_schema.
|
||||||
|
|
||||||
|
Kwargs: Currently 14/02/2025 these include
|
||||||
|
graph_id: The ID of the graph.
|
||||||
|
node_id: The ID of the node.
|
||||||
|
graph_exec_id: The ID of the graph execution.
|
||||||
|
node_exec_id: The ID of the node execution.
|
||||||
|
user_id: The ID of the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Generator that yields (output_name, output_data).
|
||||||
|
output_name: One of the output name defined in Block's output_schema.
|
||||||
|
output_data: The data for the output_name, matching the defined schema.
|
||||||
|
"""
|
||||||
|
# --- satisfy the type checker, never executed -------------
|
||||||
|
if False: # noqa: SIM115
|
||||||
|
yield "name", "value" # pyright: ignore[reportMissingYield]
|
||||||
|
raise NotImplementedError(f"{self.name} does not implement the run method.")
|
||||||
|
|
||||||
|
async def run_once(
|
||||||
|
self, input_data: BlockSchemaInputType, output: str, **kwargs
|
||||||
|
) -> Any:
|
||||||
|
async for item in self.run(input_data, **kwargs):
|
||||||
|
name, data = item
|
||||||
|
if name == output:
|
||||||
|
return data
|
||||||
|
raise ValueError(f"{self.name} did not produce any output for {output}")
|
||||||
|
|
||||||
|
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
|
||||||
|
self.execution_stats += stats
|
||||||
|
return self.execution_stats
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"inputSchema": self.input_schema.jsonschema(),
|
||||||
|
"outputSchema": self.output_schema.jsonschema(),
|
||||||
|
"description": self.description,
|
||||||
|
"categories": [category.dict() for category in self.categories],
|
||||||
|
"contributors": [
|
||||||
|
contributor.model_dump() for contributor in self.contributors
|
||||||
|
],
|
||||||
|
"staticOutput": self.static_output,
|
||||||
|
"uiType": self.block_type.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_info(self) -> BlockInfo:
|
||||||
|
from backend.data.credit import get_block_cost
|
||||||
|
|
||||||
|
return BlockInfo(
|
||||||
|
id=self.id,
|
||||||
|
name=self.name,
|
||||||
|
inputSchema=self.input_schema.jsonschema(),
|
||||||
|
outputSchema=self.output_schema.jsonschema(),
|
||||||
|
costs=get_block_cost(self),
|
||||||
|
description=self.description,
|
||||||
|
categories=[category.dict() for category in self.categories],
|
||||||
|
contributors=[
|
||||||
|
contributor.model_dump() for contributor in self.contributors
|
||||||
|
],
|
||||||
|
staticOutput=self.static_output,
|
||||||
|
uiType=self.block_type.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
async for output_name, output_data in self._execute(input_data, **kwargs):
|
||||||
|
yield output_name, output_data
|
||||||
|
except Exception as ex:
|
||||||
|
if isinstance(ex, BlockError):
|
||||||
|
raise ex
|
||||||
|
else:
|
||||||
|
raise (
|
||||||
|
BlockExecutionError
|
||||||
|
if isinstance(ex, ValueError)
|
||||||
|
else BlockUnknownError
|
||||||
|
)(
|
||||||
|
message=str(ex),
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=self.id,
|
||||||
|
) from ex
|
||||||
|
|
||||||
|
async def is_block_exec_need_review(
|
||||||
|
self,
|
||||||
|
input_data: BlockInput,
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
node_id: str,
|
||||||
|
node_exec_id: str,
|
||||||
|
graph_exec_id: str,
|
||||||
|
graph_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
execution_context: "ExecutionContext",
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[bool, BlockInput]:
|
||||||
|
"""
|
||||||
|
Check if this block execution needs human review and handle the review process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (should_pause, input_data_to_use)
|
||||||
|
- should_pause: True if execution should be paused for review
|
||||||
|
- input_data_to_use: The input data to use (may be modified by reviewer)
|
||||||
|
"""
|
||||||
|
if not (
|
||||||
|
self.is_sensitive_action and execution_context.sensitive_action_safe_mode
|
||||||
|
):
|
||||||
|
return False, input_data
|
||||||
|
|
||||||
|
from backend.blocks.helpers.review import HITLReviewHelper
|
||||||
|
|
||||||
|
# Handle the review request and get decision
|
||||||
|
decision = await HITLReviewHelper.handle_review_decision(
|
||||||
|
input_data=input_data,
|
||||||
|
user_id=user_id,
|
||||||
|
node_id=node_id,
|
||||||
|
node_exec_id=node_exec_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
graph_id=graph_id,
|
||||||
|
graph_version=graph_version,
|
||||||
|
block_name=self.name,
|
||||||
|
editable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if decision is None:
|
||||||
|
# We're awaiting review - pause execution
|
||||||
|
return True, input_data
|
||||||
|
|
||||||
|
if not decision.should_proceed:
|
||||||
|
# Review was rejected, raise an error to stop execution
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Block execution rejected by reviewer: {decision.message}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=self.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Review was approved - use the potentially modified data
|
||||||
|
# ReviewResult.data must be a dict for block inputs
|
||||||
|
reviewed_data = decision.review_result.data
|
||||||
|
if not isinstance(reviewed_data, dict):
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=self.id,
|
||||||
|
)
|
||||||
|
return False, reviewed_data
|
||||||
|
|
||||||
|
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||||
|
# Check for review requirement only if running within a graph execution context
|
||||||
|
# Direct block execution (e.g., from chat) skips the review process
|
||||||
|
has_graph_context = all(
|
||||||
|
key in kwargs
|
||||||
|
for key in (
|
||||||
|
"node_exec_id",
|
||||||
|
"graph_exec_id",
|
||||||
|
"graph_id",
|
||||||
|
"execution_context",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if has_graph_context:
|
||||||
|
should_pause, input_data = await self.is_block_exec_need_review(
|
||||||
|
input_data, **kwargs
|
||||||
|
)
|
||||||
|
if should_pause:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Validate the input data (original or reviewer-modified) once
|
||||||
|
if error := self.input_schema.validate_data(input_data):
|
||||||
|
raise BlockInputError(
|
||||||
|
message=f"Unable to execute block with invalid input data: {error}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=self.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the validated input data
|
||||||
|
async for output_name, output_data in self.run(
|
||||||
|
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if output_name == "error":
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=output_data, block_name=self.name, block_id=self.id
|
||||||
|
)
|
||||||
|
if self.block_type == BlockType.STANDARD and (
|
||||||
|
error := self.output_schema.validate_field(output_name, output_data)
|
||||||
|
):
|
||||||
|
raise BlockOutputError(
|
||||||
|
message=f"Block produced an invalid output data: {error}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=self.id,
|
||||||
|
)
|
||||||
|
yield output_name, output_data
|
||||||
|
|
||||||
|
def is_triggered_by_event_type(
|
||||||
|
self, trigger_config: dict[str, Any], event_type: str
|
||||||
|
) -> bool:
|
||||||
|
if not self.webhook_config:
|
||||||
|
raise TypeError("This method can't be used on non-trigger blocks")
|
||||||
|
if not self.webhook_config.event_filter_input:
|
||||||
|
return True
|
||||||
|
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
|
||||||
|
if not event_filter:
|
||||||
|
raise ValueError("Event filter is not configured on trigger")
|
||||||
|
return event_type in [
|
||||||
|
self.webhook_config.event_format.format(event=k)
|
||||||
|
for k in event_filter
|
||||||
|
if event_filter[k] is True
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for any block with standard input/output schemas
|
||||||
|
AnyBlockSchema: TypeAlias = Block[BlockSchemaInput, BlockSchemaOutput]
|
||||||
122
autogpt_platform/backend/backend/blocks/_utils.py
Normal file
122
autogpt_platform/backend/backend/blocks/_utils.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
from ._base import AnyBlockSchema
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def is_block_auth_configured(
|
||||||
|
block_cls: type[AnyBlockSchema],
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a block has a valid authentication method configured at runtime.
|
||||||
|
|
||||||
|
For example if a block is an OAuth-only block and there env vars are not set,
|
||||||
|
do not show it in the UI.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from backend.sdk.registry import AutoRegistry
|
||||||
|
|
||||||
|
# Create an instance to access input_schema
|
||||||
|
try:
|
||||||
|
block = block_cls()
|
||||||
|
except Exception as e:
|
||||||
|
# If we can't create a block instance, assume it's not OAuth-only
|
||||||
|
logger.error(f"Error creating block instance for {block_cls.__name__}: {e}")
|
||||||
|
return True
|
||||||
|
logger.debug(
|
||||||
|
f"Checking if block {block_cls.__name__} has a valid provider configured"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all credential inputs from input schema
|
||||||
|
credential_inputs = block.input_schema.get_credentials_fields_info()
|
||||||
|
required_inputs = block.input_schema.get_required_fields()
|
||||||
|
if not credential_inputs:
|
||||||
|
logger.debug(
|
||||||
|
f"Block {block_cls.__name__} has no credential inputs - Treating as valid"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check credential inputs
|
||||||
|
if len(required_inputs.intersection(credential_inputs.keys())) == 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Block {block_cls.__name__} has only optional credential inputs"
|
||||||
|
" - will work without credentials configured"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the credential inputs for this block are correctly configured
|
||||||
|
for field_name, field_info in credential_inputs.items():
|
||||||
|
provider_names = field_info.provider
|
||||||
|
if not provider_names:
|
||||||
|
logger.warning(
|
||||||
|
f"Block {block_cls.__name__} "
|
||||||
|
f"has credential input '{field_name}' with no provider options"
|
||||||
|
" - Disabling"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If a field has multiple possible providers, each one needs to be usable to
|
||||||
|
# prevent breaking the UX
|
||||||
|
for _provider_name in provider_names:
|
||||||
|
provider_name = _provider_name.value
|
||||||
|
if provider_name in ProviderName.__members__.values():
|
||||||
|
logger.debug(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"provider '{provider_name}' is part of the legacy provider system"
|
||||||
|
" - Treating as valid"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
provider = AutoRegistry.get_provider(provider_name)
|
||||||
|
if not provider:
|
||||||
|
logger.warning(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"refers to unknown provider '{provider_name}' - Disabling"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check the provider's supported auth types
|
||||||
|
if field_info.supported_types != provider.supported_auth_types:
|
||||||
|
logger.warning(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"has mismatched supported auth types (field <> Provider): "
|
||||||
|
f"{field_info.supported_types} != {provider.supported_auth_types}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not (supported_auth_types := provider.supported_auth_types):
|
||||||
|
# No auth methods are been configured for this provider
|
||||||
|
logger.warning(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"provider '{provider_name}' "
|
||||||
|
"has no authentication methods configured - Disabling"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if provider supports OAuth
|
||||||
|
if "oauth2" in supported_auth_types:
|
||||||
|
# Check if OAuth environment variables are set
|
||||||
|
if (oauth_config := provider.oauth_config) and bool(
|
||||||
|
os.getenv(oauth_config.client_id_env_var)
|
||||||
|
and os.getenv(oauth_config.client_secret_env_var)
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"provider '{provider_name}' is configured for OAuth"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' "
|
||||||
|
f"provider '{provider_name}' "
|
||||||
|
"is missing OAuth client ID or secret - Disabling"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Block {block_cls.__name__} credential input '{field_name}' is valid; "
|
||||||
|
f"supported credential types: {', '.join(field_info.supported_types)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockInput,
|
BlockInput,
|
||||||
@@ -9,13 +9,15 @@ from backend.data.block import (
|
|||||||
BlockSchema,
|
BlockSchema,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockType,
|
BlockType,
|
||||||
get_block,
|
|
||||||
)
|
)
|
||||||
from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
|
from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
|
||||||
from backend.data.model import NodeExecutionStats, SchemaField
|
from backend.data.model import NodeExecutionStats, SchemaField
|
||||||
from backend.util.json import validate_with_jsonschema
|
from backend.util.json import validate_with_jsonschema
|
||||||
from backend.util.retry import func_retry
|
from backend.util.retry import func_retry
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.executor.utils import LogMetadata
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -124,9 +126,10 @@ class AgentExecutorBlock(Block):
|
|||||||
graph_version: int,
|
graph_version: int,
|
||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
logger,
|
logger: "LogMetadata",
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
|
||||||
|
from backend.blocks import get_block
|
||||||
from backend.data.execution import ExecutionEventType
|
from backend.data.execution import ExecutionEventType
|
||||||
from backend.executor import utils as execution_utils
|
from backend.executor import utils as execution_utils
|
||||||
|
|
||||||
@@ -198,7 +201,7 @@ class AgentExecutorBlock(Block):
|
|||||||
self,
|
self,
|
||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
logger,
|
logger: "LogMetadata",
|
||||||
) -> None:
|
) -> None:
|
||||||
from backend.executor import utils as execution_utils
|
from backend.executor import utils as execution_utils
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.blocks._base import (
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.llm import (
|
from backend.blocks.llm import (
|
||||||
DEFAULT_LLM_MODEL,
|
DEFAULT_LLM_MODEL,
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -11,12 +17,6 @@ from backend.blocks.llm import (
|
|||||||
LLMResponse,
|
LLMResponse,
|
||||||
llm_call,
|
llm_call,
|
||||||
)
|
)
|
||||||
from backend.data.block import (
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from pydantic import SecretStr
|
|||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -5,7 +5,12 @@ from pydantic import SecretStr
|
|||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal
|
|||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -10,13 +17,6 @@ from backend.blocks.apollo.models import (
|
|||||||
PrimaryPhone,
|
PrimaryPhone,
|
||||||
SearchOrganizationsRequest,
|
SearchOrganizationsRequest,
|
||||||
)
|
)
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -14,13 +21,6 @@ from backend.blocks.apollo.models import (
|
|||||||
SearchPeopleRequest,
|
SearchPeopleRequest,
|
||||||
SenorityLevels,
|
SenorityLevels,
|
||||||
)
|
)
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -6,13 +13,6 @@ from backend.blocks.apollo._auth import (
|
|||||||
ApolloCredentialsInput,
|
ApolloCredentialsInput,
|
||||||
)
|
)
|
||||||
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from backend.data.block import BlockSchemaInput
|
from backend.blocks._base import BlockSchemaInput
|
||||||
from backend.data.model import SchemaField, UserIntegrations
|
from backend.data.model import SchemaField, UserIntegrations
|
||||||
from backend.integrations.ayrshare import AyrshareClient
|
from backend.integrations.ayrshare import AyrshareClient
|
||||||
from backend.util.clients import get_database_manager_async_client
|
from backend.util.clients import get_database_manager_async_client
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import enum
|
import enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
import shlex
|
import shlex
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Literal, Optional
|
from typing import TYPE_CHECKING, Literal, Optional
|
||||||
|
|
||||||
from e2b import AsyncSandbox as BaseAsyncSandbox
|
from e2b import AsyncSandbox as BaseAsyncSandbox
|
||||||
from pydantic import BaseModel, SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
@@ -20,6 +20,13 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.sandbox_files import (
|
||||||
|
SandboxFileOutput,
|
||||||
|
extract_and_store_sandbox_files,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.executor.utils import ExecutionContext
|
||||||
|
|
||||||
|
|
||||||
class ClaudeCodeExecutionError(Exception):
|
class ClaudeCodeExecutionError(Exception):
|
||||||
@@ -174,22 +181,15 @@ class ClaudeCodeBlock(Block):
|
|||||||
advanced=True,
|
advanced=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
class FileOutput(BaseModel):
|
|
||||||
"""A file extracted from the sandbox."""
|
|
||||||
|
|
||||||
path: str
|
|
||||||
relative_path: str # Path relative to working directory (for GitHub, etc.)
|
|
||||||
name: str
|
|
||||||
content: str
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
response: str = SchemaField(
|
response: str = SchemaField(
|
||||||
description="The output/response from Claude Code execution"
|
description="The output/response from Claude Code execution"
|
||||||
)
|
)
|
||||||
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
|
files: list[SandboxFileOutput] = SchemaField(
|
||||||
description=(
|
description=(
|
||||||
"List of text files created/modified by Claude Code during this execution. "
|
"List of text files created/modified by Claude Code during this execution. "
|
||||||
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
|
"Each file has 'path', 'relative_path', 'name', 'content', and 'workspace_ref' fields. "
|
||||||
|
"workspace_ref contains a workspace:// URI if the file was stored to workspace."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conversation_history: str = SchemaField(
|
conversation_history: str = SchemaField(
|
||||||
@@ -252,6 +252,7 @@ class ClaudeCodeBlock(Block):
|
|||||||
"relative_path": "index.html",
|
"relative_path": "index.html",
|
||||||
"name": "index.html",
|
"name": "index.html",
|
||||||
"content": "<html>Hello World</html>",
|
"content": "<html>Hello World</html>",
|
||||||
|
"workspace_ref": None,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
@@ -267,11 +268,12 @@ class ClaudeCodeBlock(Block):
|
|||||||
"execute_claude_code": lambda *args, **kwargs: (
|
"execute_claude_code": lambda *args, **kwargs: (
|
||||||
"Created index.html with hello world content", # response
|
"Created index.html with hello world content", # response
|
||||||
[
|
[
|
||||||
ClaudeCodeBlock.FileOutput(
|
SandboxFileOutput(
|
||||||
path="/home/user/index.html",
|
path="/home/user/index.html",
|
||||||
relative_path="index.html",
|
relative_path="index.html",
|
||||||
name="index.html",
|
name="index.html",
|
||||||
content="<html>Hello World</html>",
|
content="<html>Hello World</html>",
|
||||||
|
workspace_ref=None,
|
||||||
)
|
)
|
||||||
], # files
|
], # files
|
||||||
"User: Create a hello world HTML file\n"
|
"User: Create a hello world HTML file\n"
|
||||||
@@ -294,7 +296,8 @@ class ClaudeCodeBlock(Block):
|
|||||||
existing_sandbox_id: str,
|
existing_sandbox_id: str,
|
||||||
conversation_history: str,
|
conversation_history: str,
|
||||||
dispose_sandbox: bool,
|
dispose_sandbox: bool,
|
||||||
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
|
execution_context: "ExecutionContext",
|
||||||
|
) -> tuple[str, list[SandboxFileOutput], str, str, str]:
|
||||||
"""
|
"""
|
||||||
Execute Claude Code in an E2B sandbox.
|
Execute Claude Code in an E2B sandbox.
|
||||||
|
|
||||||
@@ -449,14 +452,18 @@ class ClaudeCodeBlock(Block):
|
|||||||
else:
|
else:
|
||||||
new_conversation_history = turn_entry
|
new_conversation_history = turn_entry
|
||||||
|
|
||||||
# Extract files created/modified during this run
|
# Extract files created/modified during this run and store to workspace
|
||||||
files = await self._extract_files(
|
sandbox_files = await extract_and_store_sandbox_files(
|
||||||
sandbox, working_directory, start_timestamp
|
sandbox=sandbox,
|
||||||
|
working_directory=working_directory,
|
||||||
|
execution_context=execution_context,
|
||||||
|
since_timestamp=start_timestamp,
|
||||||
|
text_only=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
response,
|
response,
|
||||||
files,
|
sandbox_files, # Already SandboxFileOutput objects
|
||||||
new_conversation_history,
|
new_conversation_history,
|
||||||
current_session_id,
|
current_session_id,
|
||||||
sandbox_id,
|
sandbox_id,
|
||||||
@@ -471,140 +478,6 @@ class ClaudeCodeBlock(Block):
|
|||||||
if dispose_sandbox and sandbox:
|
if dispose_sandbox and sandbox:
|
||||||
await sandbox.kill()
|
await sandbox.kill()
|
||||||
|
|
||||||
async def _extract_files(
|
|
||||||
self,
|
|
||||||
sandbox: BaseAsyncSandbox,
|
|
||||||
working_directory: str,
|
|
||||||
since_timestamp: str | None = None,
|
|
||||||
) -> list["ClaudeCodeBlock.FileOutput"]:
|
|
||||||
"""
|
|
||||||
Extract text files created/modified during this Claude Code execution.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sandbox: The E2B sandbox instance
|
|
||||||
working_directory: Directory to search for files
|
|
||||||
since_timestamp: ISO timestamp - only return files modified after this time
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of FileOutput objects with path, relative_path, name, and content
|
|
||||||
"""
|
|
||||||
files: list[ClaudeCodeBlock.FileOutput] = []
|
|
||||||
|
|
||||||
# Text file extensions we can safely read as text
|
|
||||||
text_extensions = {
|
|
||||||
".txt",
|
|
||||||
".md",
|
|
||||||
".html",
|
|
||||||
".htm",
|
|
||||||
".css",
|
|
||||||
".js",
|
|
||||||
".ts",
|
|
||||||
".jsx",
|
|
||||||
".tsx",
|
|
||||||
".json",
|
|
||||||
".xml",
|
|
||||||
".yaml",
|
|
||||||
".yml",
|
|
||||||
".toml",
|
|
||||||
".ini",
|
|
||||||
".cfg",
|
|
||||||
".conf",
|
|
||||||
".py",
|
|
||||||
".rb",
|
|
||||||
".php",
|
|
||||||
".java",
|
|
||||||
".c",
|
|
||||||
".cpp",
|
|
||||||
".h",
|
|
||||||
".hpp",
|
|
||||||
".cs",
|
|
||||||
".go",
|
|
||||||
".rs",
|
|
||||||
".swift",
|
|
||||||
".kt",
|
|
||||||
".scala",
|
|
||||||
".sh",
|
|
||||||
".bash",
|
|
||||||
".zsh",
|
|
||||||
".sql",
|
|
||||||
".graphql",
|
|
||||||
".env",
|
|
||||||
".gitignore",
|
|
||||||
".dockerfile",
|
|
||||||
"Dockerfile",
|
|
||||||
".vue",
|
|
||||||
".svelte",
|
|
||||||
".astro",
|
|
||||||
".mdx",
|
|
||||||
".rst",
|
|
||||||
".tex",
|
|
||||||
".csv",
|
|
||||||
".log",
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# List files recursively using find command
|
|
||||||
# Exclude node_modules and .git directories, but allow hidden files
|
|
||||||
# like .env and .gitignore (they're filtered by text_extensions later)
|
|
||||||
# Filter by timestamp to only get files created/modified during this run
|
|
||||||
safe_working_dir = shlex.quote(working_directory)
|
|
||||||
timestamp_filter = ""
|
|
||||||
if since_timestamp:
|
|
||||||
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
|
|
||||||
find_result = await sandbox.commands.run(
|
|
||||||
f"find {safe_working_dir} -type f "
|
|
||||||
f"{timestamp_filter}"
|
|
||||||
f"-not -path '*/node_modules/*' "
|
|
||||||
f"-not -path '*/.git/*' "
|
|
||||||
f"2>/dev/null"
|
|
||||||
)
|
|
||||||
|
|
||||||
if find_result.stdout:
|
|
||||||
for file_path in find_result.stdout.strip().split("\n"):
|
|
||||||
if not file_path:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if it's a text file we can read
|
|
||||||
is_text = any(
|
|
||||||
file_path.endswith(ext) for ext in text_extensions
|
|
||||||
) or file_path.endswith("Dockerfile")
|
|
||||||
|
|
||||||
if is_text:
|
|
||||||
try:
|
|
||||||
content = await sandbox.files.read(file_path)
|
|
||||||
# Handle bytes or string
|
|
||||||
if isinstance(content, bytes):
|
|
||||||
content = content.decode("utf-8", errors="replace")
|
|
||||||
|
|
||||||
# Extract filename from path
|
|
||||||
file_name = file_path.split("/")[-1]
|
|
||||||
|
|
||||||
# Calculate relative path by stripping working directory
|
|
||||||
relative_path = file_path
|
|
||||||
if file_path.startswith(working_directory):
|
|
||||||
relative_path = file_path[len(working_directory) :]
|
|
||||||
# Remove leading slash if present
|
|
||||||
if relative_path.startswith("/"):
|
|
||||||
relative_path = relative_path[1:]
|
|
||||||
|
|
||||||
files.append(
|
|
||||||
ClaudeCodeBlock.FileOutput(
|
|
||||||
path=file_path,
|
|
||||||
relative_path=relative_path,
|
|
||||||
name=file_name,
|
|
||||||
content=content,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
# Skip files that can't be read
|
|
||||||
pass
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
# If file extraction fails, return empty results
|
|
||||||
pass
|
|
||||||
|
|
||||||
return files
|
|
||||||
|
|
||||||
def _escape_prompt(self, prompt: str) -> str:
|
def _escape_prompt(self, prompt: str) -> str:
|
||||||
"""Escape the prompt for safe shell execution."""
|
"""Escape the prompt for safe shell execution."""
|
||||||
# Use single quotes and escape any single quotes in the prompt
|
# Use single quotes and escape any single quotes in the prompt
|
||||||
@@ -617,6 +490,7 @@ class ClaudeCodeBlock(Block):
|
|||||||
*,
|
*,
|
||||||
e2b_credentials: APIKeyCredentials,
|
e2b_credentials: APIKeyCredentials,
|
||||||
anthropic_credentials: APIKeyCredentials,
|
anthropic_credentials: APIKeyCredentials,
|
||||||
|
execution_context: "ExecutionContext",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
@@ -637,6 +511,7 @@ class ClaudeCodeBlock(Block):
|
|||||||
existing_sandbox_id=input_data.sandbox_id,
|
existing_sandbox_id=input_data.sandbox_id,
|
||||||
conversation_history=input_data.conversation_history,
|
conversation_history=input_data.conversation_history,
|
||||||
dispose_sandbox=input_data.dispose_sandbox,
|
dispose_sandbox=input_data.dispose_sandbox,
|
||||||
|
execution_context=execution_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield "response", response
|
yield "response", response
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Literal, Optional
|
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||||
|
|
||||||
from e2b_code_interpreter import AsyncSandbox
|
from e2b_code_interpreter import AsyncSandbox
|
||||||
from e2b_code_interpreter import Result as E2BExecutionResult
|
from e2b_code_interpreter import Result as E2BExecutionResult
|
||||||
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
||||||
from pydantic import BaseModel, Field, JsonValue, SecretStr
|
from pydantic import BaseModel, Field, JsonValue, SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
@@ -20,6 +20,13 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.sandbox_files import (
|
||||||
|
SandboxFileOutput,
|
||||||
|
extract_and_store_sandbox_files,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.executor.utils import ExecutionContext
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
@@ -85,6 +92,9 @@ class CodeExecutionResult(MainCodeExecutionResult):
|
|||||||
class BaseE2BExecutorMixin:
|
class BaseE2BExecutorMixin:
|
||||||
"""Shared implementation methods for E2B executor blocks."""
|
"""Shared implementation methods for E2B executor blocks."""
|
||||||
|
|
||||||
|
# Default working directory in E2B sandboxes
|
||||||
|
WORKING_DIR = "/home/user"
|
||||||
|
|
||||||
async def execute_code(
|
async def execute_code(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
@@ -95,14 +105,21 @@ class BaseE2BExecutorMixin:
|
|||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
sandbox_id: Optional[str] = None,
|
sandbox_id: Optional[str] = None,
|
||||||
dispose_sandbox: bool = False,
|
dispose_sandbox: bool = False,
|
||||||
|
execution_context: Optional["ExecutionContext"] = None,
|
||||||
|
extract_files: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Unified code execution method that handles all three use cases:
|
Unified code execution method that handles all three use cases:
|
||||||
1. Create new sandbox and execute (ExecuteCodeBlock)
|
1. Create new sandbox and execute (ExecuteCodeBlock)
|
||||||
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
|
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
|
||||||
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
|
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extract_files: If True and execution_context provided, extract files
|
||||||
|
created/modified during execution and store to workspace.
|
||||||
""" # noqa
|
""" # noqa
|
||||||
sandbox = None
|
sandbox = None
|
||||||
|
files: list[SandboxFileOutput] = []
|
||||||
try:
|
try:
|
||||||
if sandbox_id:
|
if sandbox_id:
|
||||||
# Connect to existing sandbox (ExecuteCodeStepBlock case)
|
# Connect to existing sandbox (ExecuteCodeStepBlock case)
|
||||||
@@ -118,6 +135,12 @@ class BaseE2BExecutorMixin:
|
|||||||
for cmd in setup_commands:
|
for cmd in setup_commands:
|
||||||
await sandbox.commands.run(cmd)
|
await sandbox.commands.run(cmd)
|
||||||
|
|
||||||
|
# Capture timestamp before execution to scope file extraction
|
||||||
|
start_timestamp = None
|
||||||
|
if extract_files:
|
||||||
|
ts_result = await sandbox.commands.run("date -u +%Y-%m-%dT%H:%M:%S")
|
||||||
|
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
|
||||||
|
|
||||||
# Execute the code
|
# Execute the code
|
||||||
execution = await sandbox.run_code(
|
execution = await sandbox.run_code(
|
||||||
code,
|
code,
|
||||||
@@ -133,7 +156,24 @@ class BaseE2BExecutorMixin:
|
|||||||
stdout_logs = "".join(execution.logs.stdout)
|
stdout_logs = "".join(execution.logs.stdout)
|
||||||
stderr_logs = "".join(execution.logs.stderr)
|
stderr_logs = "".join(execution.logs.stderr)
|
||||||
|
|
||||||
return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
|
# Extract files created/modified during this execution
|
||||||
|
if extract_files and execution_context:
|
||||||
|
files = await extract_and_store_sandbox_files(
|
||||||
|
sandbox=sandbox,
|
||||||
|
working_directory=self.WORKING_DIR,
|
||||||
|
execution_context=execution_context,
|
||||||
|
since_timestamp=start_timestamp,
|
||||||
|
text_only=False, # Include binary files too
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
results,
|
||||||
|
text_output,
|
||||||
|
stdout_logs,
|
||||||
|
stderr_logs,
|
||||||
|
sandbox.sandbox_id,
|
||||||
|
files,
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
# Dispose of sandbox if requested to reduce usage costs
|
# Dispose of sandbox if requested to reduce usage costs
|
||||||
if dispose_sandbox and sandbox:
|
if dispose_sandbox and sandbox:
|
||||||
@@ -238,6 +278,12 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
|||||||
description="Standard output logs from execution"
|
description="Standard output logs from execution"
|
||||||
)
|
)
|
||||||
stderr_logs: str = SchemaField(description="Standard error logs from execution")
|
stderr_logs: str = SchemaField(description="Standard error logs from execution")
|
||||||
|
files: list[SandboxFileOutput] = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Files created or modified during execution. "
|
||||||
|
"Each file has path, name, content, and workspace_ref (if stored)."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -259,23 +305,30 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
|||||||
("results", []),
|
("results", []),
|
||||||
("response", "Hello World"),
|
("response", "Hello World"),
|
||||||
("stdout_logs", "Hello World\n"),
|
("stdout_logs", "Hello World\n"),
|
||||||
|
("files", []),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: ( # noqa
|
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox, execution_context, extract_files: ( # noqa
|
||||||
[], # results
|
[], # results
|
||||||
"Hello World", # text_output
|
"Hello World", # text_output
|
||||||
"Hello World\n", # stdout_logs
|
"Hello World\n", # stdout_logs
|
||||||
"", # stderr_logs
|
"", # stderr_logs
|
||||||
"sandbox_id", # sandbox_id
|
"sandbox_id", # sandbox_id
|
||||||
|
[], # files
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: "ExecutionContext",
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
results, text_output, stdout, stderr, _ = await self.execute_code(
|
results, text_output, stdout, stderr, _, files = await self.execute_code(
|
||||||
api_key=credentials.api_key.get_secret_value(),
|
api_key=credentials.api_key.get_secret_value(),
|
||||||
code=input_data.code,
|
code=input_data.code,
|
||||||
language=input_data.language,
|
language=input_data.language,
|
||||||
@@ -283,6 +336,8 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
|||||||
setup_commands=input_data.setup_commands,
|
setup_commands=input_data.setup_commands,
|
||||||
timeout=input_data.timeout,
|
timeout=input_data.timeout,
|
||||||
dispose_sandbox=input_data.dispose_sandbox,
|
dispose_sandbox=input_data.dispose_sandbox,
|
||||||
|
execution_context=execution_context,
|
||||||
|
extract_files=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Determine result object shape & filter out empty formats
|
# Determine result object shape & filter out empty formats
|
||||||
@@ -296,6 +351,8 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
|||||||
yield "stdout_logs", stdout
|
yield "stdout_logs", stdout
|
||||||
if stderr:
|
if stderr:
|
||||||
yield "stderr_logs", stderr
|
yield "stderr_logs", stderr
|
||||||
|
# Always yield files (empty list if none)
|
||||||
|
yield "files", [f.model_dump() for f in files]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "error", str(e)
|
yield "error", str(e)
|
||||||
|
|
||||||
@@ -393,6 +450,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
|||||||
"Hello World\n", # stdout_logs
|
"Hello World\n", # stdout_logs
|
||||||
"", # stderr_logs
|
"", # stderr_logs
|
||||||
"sandbox_id", # sandbox_id
|
"sandbox_id", # sandbox_id
|
||||||
|
[], # files
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -401,7 +459,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
|||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
_, text_output, stdout, stderr, sandbox_id = await self.execute_code(
|
_, text_output, stdout, stderr, sandbox_id, _ = await self.execute_code(
|
||||||
api_key=credentials.api_key.get_secret_value(),
|
api_key=credentials.api_key.get_secret_value(),
|
||||||
code=input_data.setup_code,
|
code=input_data.setup_code,
|
||||||
language=input_data.language,
|
language=input_data.language,
|
||||||
@@ -500,6 +558,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
|||||||
"Hello World\n", # stdout_logs
|
"Hello World\n", # stdout_logs
|
||||||
"", # stderr_logs
|
"", # stderr_logs
|
||||||
sandbox_id, # sandbox_id
|
sandbox_id, # sandbox_id
|
||||||
|
[], # files
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -508,7 +567,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
|||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
results, text_output, stdout, stderr, _ = await self.execute_code(
|
results, text_output, stdout, stderr, _, _ = await self.execute_code(
|
||||||
api_key=credentials.api_key.get_secret_value(),
|
api_key=credentials.api_key.get_secret_value(),
|
||||||
code=input_data.step_code,
|
code=input_data.step_code,
|
||||||
language=input_data.language,
|
language=input_data.language,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from openai import AsyncOpenAI
|
|||||||
from openai.types.responses import Response as OpenAIResponse
|
from openai.types.responses import Response as OpenAIResponse
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockManualWebhookConfig,
|
BlockManualWebhookConfig,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import codecs
|
import codecs
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Any, Literal, cast
|
|||||||
import discord
|
import discord
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
Discord OAuth-based blocks.
|
Discord OAuth-based blocks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Literal
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, SecretStr
|
from pydantic import BaseModel, ConfigDict, SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ which provides access to LinkedIn profile data and related information.
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user