Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-09 22:35:54 -05:00)

Compare commits: 33 commits, refactor/a ... feat/copit
| Author | SHA1 | Date |
|---|---|---|
|  | 626980bf27 |  |
|  | 6467f6734f |  |
|  | 5a30d11416 |  |
|  | 1f4105e8f9 |  |
|  | caf9ff34e6 |  |
|  | e42b27af3c |  |
|  | 34face15d2 |  |
|  | e8fc8ee623 |  |
|  | 1a16e203b8 |  |
|  | 7d32c83f95 |  |
|  | 5dae303ce0 |  |
|  | 6e2a45b84e |  |
|  | 32f6532e9c |  |
|  | 6cbfbdd013 |  |
|  | 0c6fa60436 |  |
|  | b04e916c23 |  |
|  | 1a32ba7d9a |  |
|  | deccc26f1f |  |
|  | 9e38bd5b78 |  |
|  | a329831b0b |  |
|  | 98dd1a9480 |  |
|  | 9c7c598c7d |  |
|  | 728c40def5 |  |
|  | 0bbe8a184d |  |
|  | 7592deed63 |  |
|  | b9c759ce4f |  |
|  | cd64562e1b |  |
|  | 5efb80d47b |  |
|  | b49d8e2cba |  |
|  | 452544530d |  |
|  | 32ee7e6cf8 |  |
|  | 670663c406 |  |
|  | 0dbe4cf51e |  |
.github/workflows/classic-frontend-ci.yml (vendored, 2 changes)

@@ -49,7 +49,7 @@ jobs:
 
       - name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
        if: github.event_name == 'push'
-        uses: peter-evans/create-pull-request@v7
+        uses: peter-evans/create-pull-request@v8
        with:
          add-paths: classic/frontend/build/web
          base: ${{ github.ref_name }}

@@ -42,7 +42,7 @@ jobs:
 
       - name: Get CI failure details
        id: failure_details
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
        with:
          script: |
            const run = await github.rest.actions.getWorkflowRun({
.github/workflows/claude-dependabot.yml (vendored, 9 changes)

@@ -41,7 +41,7 @@ jobs:
          python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -78,7 +78,7 @@ jobs:
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: "22"
 

@@ -91,7 +91,7 @@ jobs:
          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
 
       - name: Cache frontend dependencies
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.pnpm-store
          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}

@@ -124,7 +124,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
        id: docker-cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/docker-cache
          # Use a versioned key for cache invalidation when image list changes

@@ -309,6 +309,7 @@ jobs:
        uses: anthropics/claude-code-action@v1
        with:
          claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
+         allowed_bots: "dependabot[bot]"
          claude_args: |
            --allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
          prompt: |
.github/workflows/claude.yml (vendored, 8 changes)

@@ -57,7 +57,7 @@ jobs:
          python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -94,7 +94,7 @@ jobs:
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: "22"
 

@@ -107,7 +107,7 @@ jobs:
          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
 
       - name: Cache frontend dependencies
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.pnpm-store
          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}

@@ -140,7 +140,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
        id: docker-cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/docker-cache
          # Use a versioned key for cache invalidation when image list changes
.github/workflows/copilot-setup-steps.yml (vendored, 8 changes)

@@ -39,7 +39,7 @@ jobs:
          python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -76,7 +76,7 @@ jobs:
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: "22"
 

@@ -89,7 +89,7 @@ jobs:
          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
 
       - name: Cache frontend dependencies
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.pnpm-store
          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}

@@ -132,7 +132,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
        id: docker-cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/docker-cache
          # Use a versioned key for cache invalidation when image list changes
.github/workflows/docs-block-sync.yml (vendored, 2 changes)

@@ -33,7 +33,7 @@ jobs:
          python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
.github/workflows/docs-claude-review.yml (vendored, 2 changes)

@@ -33,7 +33,7 @@ jobs:
          python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
.github/workflows/docs-enhance.yml (vendored, 2 changes)

@@ -38,7 +38,7 @@ jobs:
          python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
.github/workflows/platform-backend-ci.yml (vendored, 2 changes)

@@ -88,7 +88,7 @@ jobs:
        run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -17,7 +17,7 @@ jobs:
       - name: Check comment permissions and deployment status
        id: check_status
        if: github.event_name == 'issue_comment' && github.event.issue.pull_request
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
        with:
          script: |
            const commentBody = context.payload.comment.body.trim();

@@ -55,7 +55,7 @@ jobs:
 
       - name: Post permission denied comment
        if: steps.check_status.outputs.permission_denied == 'true'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
        with:
          script: |
            await github.rest.issues.createComment({

@@ -68,7 +68,7 @@ jobs:
       - name: Get PR details for deployment
        id: pr_details
        if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
        with:
          script: |
            const pr = await github.rest.pulls.get({

@@ -98,7 +98,7 @@ jobs:
 
       - name: Post deploy success comment
        if: steps.check_status.outputs.should_deploy == 'true'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
        with:
          script: |
            await github.rest.issues.createComment({

@@ -126,7 +126,7 @@ jobs:
 
       - name: Post undeploy success comment
        if: steps.check_status.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
        with:
          script: |
            await github.rest.issues.createComment({

@@ -139,7 +139,7 @@ jobs:
       - name: Check deployment status on PR close
        id: check_pr_close
        if: github.event_name == 'pull_request' && github.event.action == 'closed'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
        with:
          script: |
            const comments = await github.rest.issues.listComments({

@@ -187,7 +187,7 @@ jobs:
          github.event_name == 'pull_request' &&
          github.event.action == 'closed' &&
          steps.check_pr_close.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v7
+        uses: actions/github-script@v8
        with:
          script: |
            await github.rest.issues.createComment({
.github/workflows/platform-frontend-ci.yml (vendored, 22 changes)

@@ -42,7 +42,7 @@ jobs:
            - 'autogpt_platform/frontend/src/components/**'
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"
 

@@ -54,7 +54,7 @@ jobs:
        run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
 
       - name: Cache dependencies
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.pnpm-store
          key: ${{ steps.cache-key.outputs.key }}

@@ -74,7 +74,7 @@ jobs:
        uses: actions/checkout@v4
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"
 

@@ -82,7 +82,7 @@ jobs:
        run: corepack enable
 
       - name: Restore dependencies cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}

@@ -112,7 +112,7 @@ jobs:
          fetch-depth: 0
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"
 

@@ -120,7 +120,7 @@ jobs:
        run: corepack enable
 
       - name: Restore dependencies cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}

@@ -153,7 +153,7 @@ jobs:
          submodules: recursive
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"
 

@@ -176,7 +176,7 @@ jobs:
        uses: docker/setup-buildx-action@v3
 
       - name: Cache Docker layers
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: /tmp/.buildx-cache
          key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}

@@ -231,7 +231,7 @@ jobs:
          fi
 
       - name: Restore dependencies cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}

@@ -282,7 +282,7 @@ jobs:
          submodules: recursive
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"
 

@@ -290,7 +290,7 @@ jobs:
        run: corepack enable
 
       - name: Restore dependencies cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}
.github/workflows/platform-fullstack-ci.yml (vendored, 8 changes)

@@ -32,7 +32,7 @@ jobs:
        uses: actions/checkout@v4
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"
 

@@ -44,7 +44,7 @@ jobs:
        run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
 
       - name: Cache dependencies
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.pnpm-store
          key: ${{ steps.cache-key.outputs.key }}

@@ -68,7 +68,7 @@ jobs:
          submodules: recursive
 
       - name: Set up Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"
 

@@ -88,7 +88,7 @@ jobs:
          docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
 
       - name: Restore dependencies cache
-        uses: actions/cache@v4
+        uses: actions/cache@v5
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}
autogpt_platform/autogpt_libs/poetry.lock (generated, 1660 changes)
File diff suppressed because it is too large.
autogpt_platform/autogpt_libs/pyproject.toml

@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
 [tool.poetry.dependencies]
 python = ">=3.10,<4.0"
 colorama = "^0.4.6"
-cryptography = "^45.0"
+cryptography = "^46.0"
 expiringdict = "^1.2.2"
-fastapi = "^0.116.1"
+fastapi = "^0.128.0"
-google-cloud-logging = "^3.12.1"
+google-cloud-logging = "^3.13.0"
-launchdarkly-server-sdk = "^9.12.0"
+launchdarkly-server-sdk = "^9.14.1"
-pydantic = "^2.11.7"
+pydantic = "^2.12.5"
-pydantic-settings = "^2.10.1"
+pydantic-settings = "^2.12.0"
-pyjwt = { version = "^2.10.1", extras = ["crypto"] }
+pyjwt = { version = "^2.11.0", extras = ["crypto"] }
 redis = "^6.2.0"
-supabase = "^2.16.0"
+supabase = "^2.27.2"
-uvicorn = "^0.35.0"
+uvicorn = "^0.40.0"
 
 [tool.poetry.group.dev.dependencies]
-pyright = "^1.1.404"
+pyright = "^1.1.408"
 pytest = "^8.4.1"
-pytest-asyncio = "^1.1.0"
+pytest-asyncio = "^1.3.0"
-pytest-mock = "^3.14.1"
+pytest-mock = "^3.15.1"
 pytest-cov = "^6.2.1"
-ruff = "^0.12.11"
+ruff = "^0.15.0"
 
 [build-system]
 requires = ["poetry-core"]
@@ -27,12 +27,20 @@ class ChatConfig(BaseSettings):
     session_ttl: int = Field(default=43200, description="Session TTL in seconds")
 
     # Streaming Configuration
+    # Note: When using Claude Agent SDK, context management is handled automatically
+    # via the SDK's built-in compaction. This is mainly used for the fallback path.
     max_context_messages: int = Field(
-        default=50, ge=1, le=200, description="Maximum context messages"
+        default=100,
+        ge=1,
+        le=500,
+        description="Max context messages (SDK handles compaction automatically)",
     )
 
     stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
-    max_retries: int = Field(default=3, description="Maximum number of retries")
+    max_retries: int = Field(
+        default=3,
+        description="Max retries for fallback path (SDK handles retries internally)",
+    )
     max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
     max_agent_schedules: int = Field(
         default=30, description="Maximum number of agent schedules"

@@ -93,6 +101,12 @@ class ChatConfig(BaseSettings):
         description="Name of the prompt in Langfuse to fetch",
     )
 
+    # Claude Agent SDK Configuration
+    use_claude_agent_sdk: bool = Field(
+        default=True,
+        description="Use Claude Agent SDK for chat completions",
+    )
+
     @field_validator("api_key", mode="before")
     @classmethod
     def get_api_key(cls, v):

@@ -132,6 +146,17 @@ class ChatConfig(BaseSettings):
             v = os.getenv("CHAT_INTERNAL_API_KEY")
         return v
 
+    @field_validator("use_claude_agent_sdk", mode="before")
+    @classmethod
+    def get_use_claude_agent_sdk(cls, v):
+        """Get use_claude_agent_sdk from environment if not provided."""
+        # Check environment variable - default to True if not set
+        env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
+        if env_val:
+            return env_val in ("true", "1", "yes", "on")
+        # Default to True (SDK enabled by default)
+        return True if v is None else v
+
     # Prompt paths for different contexts
     PROMPT_PATHS: dict[str, str] = {
         "default": "prompts/chat_system.md",
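For illustration, the validator added above resolves the SDK toggle from the CHAT_USE_CLAUDE_AGENT_SDK environment variable before falling back to the field default. The following is a minimal standalone sketch of that same parsing rule, not the project's ChatConfig class itself; the helper name `resolve_sdk_flag` is hypothetical.

```python
import os


def resolve_sdk_flag(explicit: bool | None = None) -> bool:
    """Sketch of the CHAT_USE_CLAUDE_AGENT_SDK parsing rule shown in the diff above."""
    env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
    if env_val:
        # Any of these spellings enables the Claude Agent SDK path.
        return env_val in ("true", "1", "yes", "on")
    # Env var unset: default to True (SDK enabled) unless an explicit value was given.
    return True if explicit is None else explicit


# Example: "off" is non-empty but not an accepted truthy spelling, so the SDK is disabled.
os.environ["CHAT_USE_CLAUDE_AGENT_SDK"] = "off"
print(resolve_sdk_flag())  # False
```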
@@ -45,10 +45,7 @@ async def create_chat_session(
         successfulAgentRuns=SafeJson({}),
         successfulAgentSchedules=SafeJson({}),
     )
-    return await PrismaChatSession.prisma().create(
-        data=data,
-        include={"Messages": True},
-    )
+    return await PrismaChatSession.prisma().create(data=data)
 
 
 async def update_chat_session(
@@ -273,9 +273,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
     try:
         session = ChatSession.model_validate_json(raw_session)
         logger.info(
-            f"Loading session {session_id} from cache: "
-            f"message_count={len(session.messages)}, "
-            f"roles={[m.role for m in session.messages]}"
+            f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
+            f"last_roles={[m.role for m in session.messages[-3:]]}"  # Last 3 roles
         )
         return session
     except Exception as e:

@@ -317,11 +316,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
         return None
 
     messages = prisma_session.Messages
-    logger.info(
-        f"Loading session {session_id} from DB: "
-        f"has_messages={messages is not None}, "
-        f"message_count={len(messages) if messages else 0}, "
-        f"roles={[m.role for m in messages] if messages else []}"
+    logger.debug(
+        f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
+        f"roles={[m.role for m in messages[-3:]] if messages else []}"  # Last 3 roles
     )
 
     return ChatSession.from_db(prisma_session, messages)
@@ -372,10 +369,9 @@ async def _save_session_to_db(
                 "function_call": msg.function_call,
             }
         )
-    logger.info(
-        f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
-        f"roles={[m['role'] for m in messages_data]}, "
-        f"start_sequence={existing_message_count}"
+    logger.debug(
+        f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
+        f"roles={[m['role'] for m in messages_data]}"
     )
     await chat_db.add_chat_messages_batch(
         session_id=session.session_id,

@@ -415,7 +411,7 @@ async def get_chat_session(
         logger.warning(f"Unexpected cache error for session {session_id}: {e}")
 
     # Fall back to database
-    logger.info(f"Session {session_id} not in cache, checking database")
+    logger.debug(f"Session {session_id} not in cache, checking database")
     session = await _get_session_from_db(session_id)
 
     if session is None:

@@ -432,7 +428,6 @@ async def get_chat_session(
     # Cache the session from DB
     try:
         await _cache_session(session)
-        logger.info(f"Cached session {session_id} from database")
     except Exception as e:
         logger.warning(f"Failed to cache session {session_id}: {e}")
 
@@ -603,13 +598,19 @@ async def update_session_title(session_id: str, title: str) -> bool:
             logger.warning(f"Session {session_id} not found for title update")
             return False
 
-        # Invalidate cache so next fetch gets updated title
+        # Update title in cache if it exists (instead of invalidating).
+        # This prevents race conditions where cache invalidation causes
+        # the frontend to see stale DB data while streaming is still in progress.
         try:
-            redis_key = _get_session_cache_key(session_id)
-            async_redis = await get_redis_async()
-            await async_redis.delete(redis_key)
+            cached = await _get_session_from_cache(session_id)
+            if cached:
+                cached.title = title
+                await _cache_session(cached)
         except Exception as e:
-            logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
+            # Not critical - title will be correct on next full cache refresh
+            logger.warning(
+                f"Failed to update title in cache for session {session_id}: {e}"
+            )
 
         return True
     except Exception as e:
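The hunk above swaps cache invalidation for an in-place rewrite: deleting the Redis entry would force the next reader back to the database, which can still hold a pre-stream snapshot while a response is being generated. A minimal sketch of that pattern is below; `Session`, `get_cached`, and `put_cached` are stand-ins for the project's ChatSession model and its `_get_session_from_cache` / `_cache_session` helpers, and a plain dict stands in for Redis.

```python
from dataclasses import dataclass


@dataclass
class Session:
    session_id: str
    title: str


_cache: dict[str, Session] = {}  # stand-in for the Redis session cache


def get_cached(session_id: str) -> Session | None:
    return _cache.get(session_id)


def put_cached(session: Session) -> None:
    _cache[session.session_id] = session


def update_title(session_id: str, title: str) -> None:
    # Rewrite the cached copy instead of deleting it, so concurrent readers keep
    # seeing the in-progress session rather than a stale database snapshot.
    cached = get_cached(session_id)
    if cached:
        cached.title = title
        put_cached(cached)
```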
@@ -1,5 +1,6 @@
 """Chat API routes for chat session management and streaming via SSE."""
 
+import asyncio
 import logging
 import uuid as uuid_module
 from collections.abc import AsyncGenerator

@@ -16,8 +17,17 @@ from . import service as chat_service
 from . import stream_registry
 from .completion_handler import process_operation_failure, process_operation_success
 from .config import ChatConfig
-from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
+from .model import (
+    ChatMessage,
+    ChatSession,
+    create_chat_session,
+    get_chat_session,
+    get_user_sessions,
+    upsert_chat_session,
+)
 from .response_model import StreamFinish, StreamHeartbeat, StreamStart
+from .sdk import service as sdk_service
+from .tracking import track_user_message
 
 config = ChatConfig()
 
@@ -209,6 +219,10 @@ async def get_session(
     active_task, last_message_id = await stream_registry.get_active_task_for_session(
         session_id, user_id
     )
+    logger.info(
+        f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
+        f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
+    )
     if active_task:
         # Filter out the in-progress assistant message from the session response.
         # The client will receive the complete assistant response through the SSE
@@ -266,12 +280,59 @@ async def stream_chat_post(
 
     """
     import asyncio
+    import time
+
+    stream_start_time = time.perf_counter()
+
+    # Base log metadata (task_id added after creation)
+    log_meta = {"component": "ChatStream", "session_id": session_id}
+    if user_id:
+        log_meta["user_id"] = user_id
+
+    logger.info(
+        f"[TIMING] stream_chat_post STARTED, session={session_id}, "
+        f"user={user_id}, message_len={len(request.message)}",
+        extra={"json_fields": log_meta},
+    )
     session = await _validate_and_get_session(session_id, user_id)
+    logger.info(
+        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
+        extra={
+            "json_fields": {
+                **log_meta,
+                "duration_ms": (time.perf_counter() - stream_start_time) * 1000,
+            }
+        },
+    )
+
+    # Add user message to session BEFORE creating task to avoid race condition
+    # where GET_SESSION sees the task as "running" but the message isn't saved yet
+    if request.message:
+        session.messages.append(
+            ChatMessage(
+                role="user" if request.is_user_message else "assistant",
+                content=request.message,
+            )
+        )
+        if request.is_user_message:
+            track_user_message(
+                user_id=user_id,
+                session_id=session_id,
+                message_length=len(request.message),
+            )
+        logger.info(
+            f"[STREAM] Saving user message to session {session_id}, "
+            f"msg_count={len(session.messages)}"
+        )
+        session = await upsert_chat_session(session)
+        logger.info(f"[STREAM] User message saved for session {session_id}")
 
     # Create a task in the stream registry for reconnection support
     task_id = str(uuid_module.uuid4())
     operation_id = str(uuid_module.uuid4())
+    log_meta["task_id"] = task_id
+
+    task_create_start = time.perf_counter()
     await stream_registry.create_task(
         task_id=task_id,
         session_id=session_id,
@@ -280,72 +341,260 @@ async def stream_chat_post(
         tool_name="chat",
         operation_id=operation_id,
     )
+    logger.info(
+        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
+        extra={
+            "json_fields": {
+                **log_meta,
+                "duration_ms": (time.perf_counter() - task_create_start) * 1000,
+            }
+        },
+    )
+
     # Background task that runs the AI generation independently of SSE connection
     async def run_ai_generation():
+        import time as time_module
+
+        gen_start_time = time_module.perf_counter()
+        logger.info(
+            f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
+            extra={"json_fields": log_meta},
+        )
+        first_chunk_time, ttfc = None, None
+        chunk_count = 0
         try:
             # Emit a start event with task_id for reconnection
             start_chunk = StreamStart(messageId=task_id, taskId=task_id)
             await stream_registry.publish_chunk(task_id, start_chunk)
+            logger.info(
+                f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time)*1000:.1f}ms",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "elapsed_ms": (time_module.perf_counter() - gen_start_time)
+                        * 1000,
+                    }
+                },
+            )
+
+            # Choose service based on configuration
+            use_sdk = config.use_claude_agent_sdk
+            stream_fn = (
+                sdk_service.stream_chat_completion_sdk
+                if use_sdk
+                else chat_service.stream_chat_completion
+            )
+            logger.info(
+                f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
+                extra={"json_fields": log_meta},
+            )
+            # Pass message=None since we already added it to the session above
-            async for chunk in chat_service.stream_chat_completion(
+            async for chunk in stream_fn(
                 session_id,
-                request.message,
+                None,  # Message already in session
                 is_user_message=request.is_user_message,
                 user_id=user_id,
-                session=session,  # Pass pre-fetched session to avoid double-fetch
+                session=session,  # Pass session with message already added
                 context=request.context,
             ):
+                chunk_count += 1
+                if first_chunk_time is None:
+                    first_chunk_time = time_module.perf_counter()
+                    ttfc = first_chunk_time - gen_start_time
+                    logger.info(
+                        f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
+                        extra={
+                            "json_fields": {
+                                **log_meta,
+                                "chunk_type": type(chunk).__name__,
+                                "time_to_first_chunk_ms": ttfc * 1000,
+                            }
+                        },
+                    )
                 # Write to Redis (subscribers will receive via XREAD)
                 await stream_registry.publish_chunk(task_id, chunk)
 
-            # Mark task as completed
+            gen_end_time = time_module.perf_counter()
+            total_time = (gen_end_time - gen_start_time) * 1000
+            logger.info(
+                f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
+                f"task={task_id}, session={session_id}, "
+                f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "total_time_ms": total_time,
+                        "time_to_first_chunk_ms": (
+                            ttfc * 1000 if ttfc is not None else None
+                        ),
+                        "n_chunks": chunk_count,
+                    }
+                },
+            )
+
             await stream_registry.mark_task_completed(task_id, "completed")
         except Exception as e:
+            elapsed = time_module.perf_counter() - gen_start_time
             logger.error(
-                f"Error in background AI generation for session {session_id}: {e}"
+                f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "elapsed_ms": elapsed * 1000,
+                        "error": str(e),
+                    }
+                },
             )
             await stream_registry.mark_task_completed(task_id, "failed")
 
     # Start the AI generation in a background task
     bg_task = asyncio.create_task(run_ai_generation())
     await stream_registry.set_task_asyncio_task(task_id, bg_task)
+    setup_time = (time.perf_counter() - stream_start_time) * 1000
+    logger.info(
+        f"[TIMING] Background task started, setup={setup_time:.1f}ms",
+        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
+    )
 
     # SSE endpoint that subscribes to the task's stream
     async def event_generator() -> AsyncGenerator[str, None]:
+        import time as time_module
+
+        event_gen_start = time_module.perf_counter()
+        logger.info(
+            f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
+            f"user={user_id}",
+            extra={"json_fields": log_meta},
+        )
         subscriber_queue = None
+        first_chunk_yielded = False
+        chunks_yielded = 0
         try:
             # Subscribe to the task stream (this replays existing messages + live updates)
+            subscribe_start = time_module.perf_counter()
+            logger.info(
+                "[TIMING] Calling subscribe_to_task",
+                extra={"json_fields": log_meta},
+            )
             subscriber_queue = await stream_registry.subscribe_to_task(
                 task_id=task_id,
                 user_id=user_id,
                 last_message_id="0-0",  # Get all messages from the beginning
             )
+            subscribe_time = (time_module.perf_counter() - subscribe_start) * 1000
+            logger.info(
+                f"[TIMING] subscribe_to_task completed in {subscribe_time:.1f}ms, "
+                f"queue_ok={subscriber_queue is not None}",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "duration_ms": subscribe_time,
+                        "queue_obtained": subscriber_queue is not None,
+                    }
+                },
+            )
+
             if subscriber_queue is None:
+                logger.info(
+                    "[TIMING] subscriber_queue is None, yielding finish",
+                    extra={"json_fields": log_meta},
+                )
                 yield StreamFinish().to_sse()
                 yield "data: [DONE]\n\n"
                 return
+
             # Read from the subscriber queue and yield to SSE
+            logger.info(
+                "[TIMING] Starting to read from subscriber_queue",
+                extra={"json_fields": log_meta},
+            )
             while True:
                 try:
+                    queue_wait_start = time_module.perf_counter()
                     chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
+                    queue_wait_time = (
+                        time_module.perf_counter() - queue_wait_start
+                    ) * 1000
+                    chunks_yielded += 1
+
+                    if not first_chunk_yielded:
+                        first_chunk_yielded = True
+                        elapsed = time_module.perf_counter() - event_gen_start
+                        logger.info(
+                            f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
+                            f"type={type(chunk).__name__}, "
+                            f"wait={queue_wait_time:.1f}ms",
+                            extra={
+                                "json_fields": {
+                                    **log_meta,
+                                    "chunk_type": type(chunk).__name__,
+                                    "elapsed_ms": elapsed * 1000,
+                                    "queue_wait_ms": queue_wait_time,
+                                }
+                            },
+                        )
+                    elif chunks_yielded % 50 == 0:
+                        logger.info(
+                            f"[TIMING] Chunk #{chunks_yielded}, "
+                            f"type={type(chunk).__name__}",
+                            extra={
+                                "json_fields": {
+                                    **log_meta,
+                                    "chunk_number": chunks_yielded,
+                                    "chunk_type": type(chunk).__name__,
+                                }
+                            },
+                        )
+
                     yield chunk.to_sse()
+
                     # Check for finish signal
                     if isinstance(chunk, StreamFinish):
+                        total_time = time_module.perf_counter() - event_gen_start
+                        logger.info(
+                            f"[TIMING] StreamFinish received in {total_time:.2f}s; "
+                            f"n_chunks={chunks_yielded}",
+                            extra={
+                                "json_fields": {
+                                    **log_meta,
+                                    "chunks_yielded": chunks_yielded,
+                                    "total_time_ms": total_time * 1000,
+                                }
+                            },
+                        )
                         break
                 except asyncio.TimeoutError:
                     # Send heartbeat to keep connection alive
+                    logger.info(
+                        f"[TIMING] Heartbeat timeout, chunks_so_far={chunks_yielded}",
+                        extra={
+                            "json_fields": {**log_meta, "chunks_so_far": chunks_yielded}
+                        },
+                    )
                     yield StreamHeartbeat().to_sse()
+
         except GeneratorExit:
+            logger.info(
+                f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "chunks_yielded": chunks_yielded,
+                        "reason": "client_disconnect",
+                    }
+                },
+            )
             pass  # Client disconnected - background task continues
         except Exception as e:
-            logger.error(f"Error in SSE stream for task {task_id}: {e}")
+            elapsed = (time_module.perf_counter() - event_gen_start) * 1000
+            logger.error(
+                f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
+                extra={
+                    "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
+                },
+            )
         finally:
-            # Unsubscribe when client disconnects or stream ends to prevent resource leak
+            # Unsubscribe when client disconnects or stream ends
             if subscriber_queue is not None:
                 try:
                     await stream_registry.unsubscribe_from_task(

@@ -357,6 +606,18 @@ async def stream_chat_post(
                         exc_info=True,
                     )
             # AI SDK protocol termination - always yield even if unsubscribe fails
+            total_time = time_module.perf_counter() - event_gen_start
+            logger.info(
+                f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
+                f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "total_time_ms": total_time * 1000,
+                        "chunks_yielded": chunks_yielded,
+                    }
+                },
+            )
             yield "data: [DONE]\n\n"
 
     return StreamingResponse(
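The event generator added above follows a common SSE keep-alive pattern: block on the subscriber queue with a timeout, emit a heartbeat when the timeout fires, and stop once a finish marker arrives. Below is a minimal standalone sketch of that loop under stated assumptions; the queue, the FINISH sentinel, and the frame formatting are stand-ins rather than the project's StreamFinish/StreamHeartbeat models.

```python
import asyncio
from collections.abc import AsyncGenerator

FINISH = object()  # stand-in for a StreamFinish-style sentinel


async def sse_events(
    queue: asyncio.Queue, heartbeat_every: float = 30.0
) -> AsyncGenerator[str, None]:
    """Yield SSE frames from a queue, sending heartbeats while the producer is quiet."""
    while True:
        try:
            item = await asyncio.wait_for(queue.get(), timeout=heartbeat_every)
        except asyncio.TimeoutError:
            # An SSE comment frame keeps the connection alive without emitting data.
            yield ": heartbeat\n\n"
            continue
        if item is FINISH:
            break
        yield f"data: {item}\n\n"
    # Protocol termination marker, mirroring the "data: [DONE]" frame in the diff above.
    yield "data: [DONE]\n\n"
```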
@@ -400,35 +661,21 @@ async def stream_chat_get(
     session = await _validate_and_get_session(session_id, user_id)
 
     async def event_generator() -> AsyncGenerator[str, None]:
-        chunk_count = 0
-        first_chunk_type: str | None = None
-        async for chunk in chat_service.stream_chat_completion(
+        # Choose service based on configuration
+        use_sdk = config.use_claude_agent_sdk
+        stream_fn = (
+            sdk_service.stream_chat_completion_sdk
+            if use_sdk
+            else chat_service.stream_chat_completion
+        )
+        async for chunk in stream_fn(
             session_id,
             message,
             is_user_message=is_user_message,
             user_id=user_id,
             session=session,  # Pass pre-fetched session to avoid double-fetch
         ):
-            if chunk_count < 3:
-                logger.info(
-                    "Chat stream chunk",
-                    extra={
-                        "session_id": session_id,
-                        "chunk_type": str(chunk.type),
-                    },
-                )
-                if not first_chunk_type:
-                    first_chunk_type = str(chunk.type)
-            chunk_count += 1
             yield chunk.to_sse()
-        logger.info(
-            "Chat stream completed",
-            extra={
-                "session_id": session_id,
-                "chunk_count": chunk_count,
-                "first_chunk_type": first_chunk_type,
-            },
-        )
         # AI SDK protocol termination
         yield "data: [DONE]\n\n"
 
@@ -550,8 +797,6 @@ async def stream_task(
     )
 
     async def event_generator() -> AsyncGenerator[str, None]:
-        import asyncio
-
         heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
         try:
             while True:
@@ -0,0 +1,14 @@
+"""Claude Agent SDK integration for CoPilot.
+
+This module provides the integration layer between the Claude Agent SDK
+and the existing CoPilot tool system, enabling drop-in replacement of
+the current LLM orchestration with the battle-tested Claude Agent SDK.
+"""
+
+from .service import stream_chat_completion_sdk
+from .tool_adapter import create_copilot_mcp_server
+
+__all__ = [
+    "stream_chat_completion_sdk",
+    "create_copilot_mcp_server",
+]
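As a usage illustration, the exported `stream_chat_completion_sdk` is an async generator that is consumed the same way the routes above call `stream_fn(...)`. The sketch below assumes that call signature and a `delta` attribute on text chunks (as the fallback's StreamTextDelta suggests); the helper name and attribute handling are hypothetical, not the module's documented API.

```python
from .sdk import stream_chat_completion_sdk


async def collect_reply(session_id: str, user_id: str, session, message: str) -> str:
    """Accumulate the assistant's text from the SDK-backed stream (illustrative only)."""
    text = ""
    async for chunk in stream_chat_completion_sdk(
        session_id,
        message,
        is_user_message=True,
        user_id=user_id,
        session=session,
    ):
        # Each chunk is one of the Stream* response models; text deltas carry the reply.
        delta = getattr(chunk, "delta", None)
        if delta:
            text += delta
    return text
```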
@@ -0,0 +1,348 @@
"""Anthropic SDK fallback implementation.

This module provides the fallback streaming implementation using the Anthropic SDK
directly when the Claude Agent SDK is not available.
"""

import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from typing import Any, cast

from ..model import ChatMessage, ChatSession
from ..response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
    StreamUsage,
)
from .tool_adapter import get_tool_definitions, get_tool_handlers

logger = logging.getLogger(__name__)


async def stream_with_anthropic(
    session: ChatSession,
    system_prompt: str,
    text_block_id: str,
) -> AsyncGenerator[StreamBaseResponse, None]:
    """Stream using Anthropic SDK directly with tool calling support.

    This function accumulates messages into the session for persistence.
    The caller should NOT yield an additional StreamFinish - this function handles it.
    """
    import anthropic

    # Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
    api_key = os.getenv("ANTHROPIC_API_KEY")
    if not api_key:
        yield StreamError(
            errorText="ANTHROPIC_API_KEY not configured for fallback",
            code="config_error",
        )
        yield StreamFinish()
        return

    client = anthropic.AsyncAnthropic(api_key=api_key)
    tool_definitions = get_tool_definitions()
    tool_handlers = get_tool_handlers()

    anthropic_tools = [
        {
            "name": t["name"],
            "description": t["description"],
            "input_schema": t["inputSchema"],
        }
        for t in tool_definitions
    ]

    anthropic_messages = _convert_session_to_anthropic(session)

    if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
        anthropic_messages.append(
            {"role": "user", "content": "Continue with the task."}
        )

    has_started_text = False
    max_iterations = 10
    accumulated_text = ""
    accumulated_tool_calls: list[dict[str, Any]] = []

    for _ in range(max_iterations):
        try:
            async with client.messages.stream(
                model="claude-sonnet-4-20250514",
                max_tokens=4096,
                system=system_prompt,
                messages=cast(Any, anthropic_messages),
                tools=cast(Any, anthropic_tools) if anthropic_tools else [],
            ) as stream:
                async for event in stream:
                    if event.type == "content_block_start":
                        block = event.content_block
                        if hasattr(block, "type"):
                            if block.type == "text" and not has_started_text:
                                yield StreamTextStart(id=text_block_id)
                                has_started_text = True
                            elif block.type == "tool_use":
                                yield StreamToolInputStart(
                                    toolCallId=block.id, toolName=block.name
                                )

                    elif event.type == "content_block_delta":
                        delta = event.delta
                        if hasattr(delta, "type") and delta.type == "text_delta":
                            accumulated_text += delta.text
                            yield StreamTextDelta(id=text_block_id, delta=delta.text)

                final_message = await stream.get_final_message()

                if final_message.stop_reason == "tool_use":
                    if has_started_text:
                        yield StreamTextEnd(id=text_block_id)
                        has_started_text = False
                        text_block_id = str(uuid.uuid4())

                    tool_results = []
                    assistant_content: list[dict[str, Any]] = []

                    for block in final_message.content:
                        if block.type == "text":
                            assistant_content.append(
                                {"type": "text", "text": block.text}
                            )
                        elif block.type == "tool_use":
                            assistant_content.append(
                                {
                                    "type": "tool_use",
                                    "id": block.id,
                                    "name": block.name,
                                    "input": block.input,
                                }
                            )

                            # Track tool call for session persistence
                            accumulated_tool_calls.append(
                                {
                                    "id": block.id,
                                    "type": "function",
                                    "function": {
                                        "name": block.name,
                                        "arguments": json.dumps(
                                            block.input
                                            if isinstance(block.input, dict)
                                            else {}
                                        ),
                                    },
                                }
                            )

                            yield StreamToolInputAvailable(
                                toolCallId=block.id,
                                toolName=block.name,
                                input=(
                                    block.input if isinstance(block.input, dict) else {}
                                ),
                            )

                            output, is_error = await _execute_tool(
                                block.name, block.input, tool_handlers
                            )

                            yield StreamToolOutputAvailable(
                                toolCallId=block.id,
                                toolName=block.name,
                                output=output,
                                success=not is_error,
                            )

                            # Save tool result to session
                            session.messages.append(
                                ChatMessage(
                                    role="tool",
                                    content=output,
                                    tool_call_id=block.id,
                                )
                            )

                            tool_results.append(
                                {
                                    "type": "tool_result",
                                    "tool_use_id": block.id,
                                    "content": output,
                                    "is_error": is_error,
                                }
                            )

                    # Save assistant message with tool calls to session
                    session.messages.append(
                        ChatMessage(
                            role="assistant",
                            content=accumulated_text or None,
                            tool_calls=(
                                accumulated_tool_calls
                                if accumulated_tool_calls
                                else None
                            ),
                        )
                    )
                    # Reset for next iteration
                    accumulated_text = ""
                    accumulated_tool_calls = []

                    anthropic_messages.append(
                        {"role": "assistant", "content": assistant_content}
                    )
                    anthropic_messages.append({"role": "user", "content": tool_results})
                    continue

                else:
                    if has_started_text:
                        yield StreamTextEnd(id=text_block_id)

                    # Save final assistant response to session
                    if accumulated_text:
                        session.messages.append(
                            ChatMessage(role="assistant", content=accumulated_text)
                        )

                    yield StreamUsage(
                        promptTokens=final_message.usage.input_tokens,
                        completionTokens=final_message.usage.output_tokens,
                        totalTokens=final_message.usage.input_tokens
                        + final_message.usage.output_tokens,
                    )
                    yield StreamFinish()
                    return

        except Exception as e:
            logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
            yield StreamError(
                errorText="An error occurred. Please try again.",
                code="anthropic_error",
            )
            yield StreamFinish()
            return

    yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
    yield StreamFinish()


def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
    """Convert session messages to Anthropic format.

    Handles merging consecutive same-role messages (Anthropic requires alternating roles).
    """
    messages: list[dict[str, Any]] = []

    for msg in session.messages:
        if msg.role == "user":
            new_msg = {"role": "user", "content": msg.content or ""}
        elif msg.role == "assistant":
            content: list[dict[str, Any]] = []
            if msg.content:
                content.append({"type": "text", "text": msg.content})
            if msg.tool_calls:
                for tc in msg.tool_calls:
                    func = tc.get("function", {})
                    args = func.get("arguments", {})
                    if isinstance(args, str):
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            args = {}
                    content.append(
                        {
                            "type": "tool_use",
                            "id": tc.get("id", str(uuid.uuid4())),
                            "name": func.get("name", ""),
                            "input": args,
                        }
                    )
            if content:
                new_msg = {"role": "assistant", "content": content}
            else:
                continue  # Skip empty assistant messages
        elif msg.role == "tool":
            new_msg = {
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": msg.tool_call_id or "",
                        "content": msg.content or "",
                    }
                ],
            }
        else:
            continue

        messages.append(new_msg)

    # Merge consecutive same-role messages (Anthropic requires alternating roles)
    return _merge_consecutive_roles(messages)


def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Merge consecutive messages with the same role.

    Anthropic API requires alternating user/assistant roles.
    """
    if not messages:
        return []

    merged: list[dict[str, Any]] = []
    for msg in messages:
        if merged and merged[-1]["role"] == msg["role"]:
            # Merge with previous message
            prev_content = merged[-1]["content"]
            new_content = msg["content"]

            # Normalize both to list-of-blocks form
            if isinstance(prev_content, str):
                prev_content = [{"type": "text", "text": prev_content}]
            if isinstance(new_content, str):
                new_content = [{"type": "text", "text": new_content}]

            # Ensure both are lists
            if not isinstance(prev_content, list):
                prev_content = [prev_content]
            if not isinstance(new_content, list):
                new_content = [new_content]

            merged[-1]["content"] = prev_content + new_content
        else:
            merged.append(msg)

    return merged


async def _execute_tool(
    tool_name: str, tool_input: Any, handlers: dict[str, Any]
) -> tuple[str, bool]:
    """Execute a tool and return (output, is_error)."""
    handler = handlers.get(tool_name)
    if not handler:
        return f"Unknown tool: {tool_name}", True

    try:
        result = await handler(tool_input)
        # Safely extract output - handle empty or missing content
        content = result.get("content") or []
        if content and isinstance(content, list) and len(content) > 0:
            first_item = content[0]
            output = first_item.get("text", "") if isinstance(first_item, dict) else ""
        else:
            output = ""
        is_error = result.get("isError", False)
        return output, is_error
    except Exception as e:
        return f"Error: {str(e)}", True
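
Not part of the diff: a minimal standalone sketch of the role-merging rule the fallback above depends on, since the Anthropic Messages API requires user and assistant turns to strictly alternate. The function name and demo history below are illustrative only.

# Standalone sketch: collapse adjacent same-role messages into one alternating-safe turn.
from typing import Any


def merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    merged: list[dict[str, Any]] = []
    for msg in messages:
        if merged and merged[-1]["role"] == msg["role"]:
            prev = merged[-1]["content"]
            new = msg["content"]
            # Normalize plain strings into Anthropic-style content blocks before merging
            if isinstance(prev, str):
                prev = [{"type": "text", "text": prev}]
            if isinstance(new, str):
                new = [{"type": "text", "text": new}]
            merged[-1]["content"] = prev + new
        else:
            merged.append(dict(msg))
    return merged


if __name__ == "__main__":
    history = [
        {"role": "user", "content": "Run my report agent"},
        {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "t1", "content": "done"}]},
        {"role": "assistant", "content": "The report finished successfully."},
    ]
    # The two consecutive "user" entries (a plain message plus a tool result) merge into one turn.
    print(merge_consecutive_roles(history))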
@@ -0,0 +1,320 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""

import json
import logging
import uuid
from typing import Any, AsyncGenerator

from backend.api.features.chat.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamHeartbeat,
    StreamStart,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
    StreamUsage,
)

logger = logging.getLogger(__name__)


class SDKResponseAdapter:
    """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

    This class maintains state during a streaming session to properly track
    text blocks, tool calls, and message lifecycle.
    """

    def __init__(self, message_id: str | None = None):
        """Initialize the adapter.

        Args:
            message_id: Optional message ID. If not provided, one will be generated.
        """
        self.message_id = message_id or str(uuid.uuid4())
        self.text_block_id = str(uuid.uuid4())
        self.has_started_text = False
        self.has_ended_text = False
        self.current_tool_calls: dict[str, dict[str, Any]] = {}
        self.task_id: str | None = None

    def set_task_id(self, task_id: str) -> None:
        """Set the task ID for reconnection support."""
        self.task_id = task_id

    def convert_message(self, sdk_message: Any) -> list[StreamBaseResponse]:
        """Convert a single SDK message to Vercel AI SDK format.

        Args:
            sdk_message: A message from the Claude Agent SDK.

        Returns:
            List of StreamBaseResponse objects (may be empty or multiple).
        """
        responses: list[StreamBaseResponse] = []

        # Handle different SDK message types - use class name since SDK uses dataclasses
        class_name = type(sdk_message).__name__
        msg_subtype = getattr(sdk_message, "subtype", None)

        if class_name == "SystemMessage":
            if msg_subtype == "init":
                # Session initialization - emit start
                responses.append(
                    StreamStart(
                        messageId=self.message_id,
                        taskId=self.task_id,
                    )
                )

        elif class_name == "AssistantMessage":
            # Assistant message with content blocks
            content = getattr(sdk_message, "content", [])
            for block in content:
                # Check block type by class name (SDK uses dataclasses) or dict type
                block_class = type(block).__name__
                block_type = block.get("type") if isinstance(block, dict) else None

                if block_class == "TextBlock" or block_type == "text":
                    # Text content
                    text = getattr(block, "text", None) or (
                        block.get("text") if isinstance(block, dict) else ""
                    )

                    if text:
                        # Start text block if needed (or restart after tool calls)
                        if not self.has_started_text or self.has_ended_text:
                            # Generate new text block ID for text after tools
                            if self.has_ended_text:
                                self.text_block_id = str(uuid.uuid4())
                                self.has_ended_text = False
                            responses.append(StreamTextStart(id=self.text_block_id))
                            self.has_started_text = True

                        # Emit text delta
                        responses.append(
                            StreamTextDelta(
                                id=self.text_block_id,
                                delta=text,
                            )
                        )

                elif block_class == "ToolUseBlock" or block_type == "tool_use":
                    # Tool call
                    tool_id_raw = getattr(block, "id", None) or (
                        block.get("id") if isinstance(block, dict) else None
                    )
                    tool_id: str = (
                        str(tool_id_raw) if tool_id_raw else str(uuid.uuid4())
                    )

                    tool_name_raw = getattr(block, "name", None) or (
                        block.get("name") if isinstance(block, dict) else None
                    )
                    tool_name: str = str(tool_name_raw) if tool_name_raw else "unknown"

                    tool_input = getattr(block, "input", None) or (
                        block.get("input") if isinstance(block, dict) else {}
                    )

                    # End text block if we were streaming text
                    if self.has_started_text and not self.has_ended_text:
                        responses.append(StreamTextEnd(id=self.text_block_id))
                        self.has_ended_text = True

                    # Emit tool input start
                    responses.append(
                        StreamToolInputStart(
                            toolCallId=tool_id,
                            toolName=tool_name,
                        )
                    )

                    # Emit tool input available with full input
                    responses.append(
                        StreamToolInputAvailable(
                            toolCallId=tool_id,
                            toolName=tool_name,
                            input=tool_input if isinstance(tool_input, dict) else {},
                        )
                    )

                    # Track the tool call
                    self.current_tool_calls[tool_id] = {
                        "name": tool_name,
                        "input": tool_input,
                    }

        elif class_name in ("ToolResultMessage", "UserMessage"):
            # Tool result - check for tool_result content
            content = getattr(sdk_message, "content", [])

            for block in content:
                block_class = type(block).__name__
                block_type = block.get("type") if isinstance(block, dict) else None

                if block_class == "ToolResultBlock" or block_type == "tool_result":
                    tool_use_id = getattr(block, "tool_use_id", None) or (
                        block.get("tool_use_id") if isinstance(block, dict) else None
                    )
                    result_content = getattr(block, "content", None) or (
                        block.get("content") if isinstance(block, dict) else ""
                    )
                    is_error = getattr(block, "is_error", False) or (
                        block.get("is_error", False)
                        if isinstance(block, dict)
                        else False
                    )

                    if tool_use_id:
                        tool_info = self.current_tool_calls.get(tool_use_id, {})
                        tool_name = tool_info.get("name", "unknown")

                        # Format the output
                        if isinstance(result_content, list):
                            # Extract text from content blocks
                            output_text = ""
                            for item in result_content:
                                if (
                                    isinstance(item, dict)
                                    and item.get("type") == "text"
                                ):
                                    output_text += item.get("text", "")
                                elif hasattr(item, "text"):
                                    output_text += getattr(item, "text", "")
                            if output_text:
                                output = output_text
                            else:
                                try:
                                    output = json.dumps(result_content)
                                except (TypeError, ValueError):
                                    output = str(result_content)
                        elif isinstance(result_content, str):
                            output = result_content
                        else:
                            try:
                                output = json.dumps(result_content)
                            except (TypeError, ValueError):
                                output = str(result_content)

                        responses.append(
                            StreamToolOutputAvailable(
                                toolCallId=tool_use_id,
                                toolName=tool_name,
                                output=output,
                                success=not is_error,
                            )
                        )

        elif class_name == "ResultMessage":
            # Final result
            if msg_subtype == "success":
                # End text block if still open
                if self.has_started_text and not self.has_ended_text:
                    responses.append(StreamTextEnd(id=self.text_block_id))
                    self.has_ended_text = True

                # Emit finish
                responses.append(StreamFinish())

            elif msg_subtype in ("error", "error_during_execution"):
                error_msg = getattr(sdk_message, "error", "Unknown error")
                responses.append(
                    StreamError(
                        errorText=str(error_msg),
                        code="sdk_error",
                    )
                )
                responses.append(StreamFinish())

        elif class_name == "ErrorMessage":
            # Error message
            error_msg = getattr(sdk_message, "message", None) or getattr(
                sdk_message, "error", "Unknown error"
            )
            responses.append(
                StreamError(
                    errorText=str(error_msg),
                    code="sdk_error",
                )
            )
            responses.append(StreamFinish())

        else:
            logger.debug(f"Unhandled SDK message type: {class_name}")

        return responses

    def create_heartbeat(self, tool_call_id: str | None = None) -> StreamHeartbeat:
        """Create a heartbeat response."""
        return StreamHeartbeat(toolCallId=tool_call_id)

    def create_usage(
        self,
        prompt_tokens: int,
        completion_tokens: int,
    ) -> StreamUsage:
        """Create a usage statistics response."""
        return StreamUsage(
            promptTokens=prompt_tokens,
            completionTokens=completion_tokens,
            totalTokens=prompt_tokens + completion_tokens,
        )


async def adapt_sdk_stream(
    sdk_stream: AsyncGenerator[Any, None],
    message_id: str | None = None,
    task_id: str | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
    """Adapt a Claude Agent SDK stream to Vercel AI SDK format.

    Args:
        sdk_stream: The async generator from the Claude Agent SDK.
        message_id: Optional message ID for the response.
        task_id: Optional task ID for reconnection support.

    Yields:
        StreamBaseResponse objects in Vercel AI SDK format.
    """
    adapter = SDKResponseAdapter(message_id=message_id)
    if task_id:
        adapter.set_task_id(task_id)

    # Emit start immediately
    yield StreamStart(messageId=adapter.message_id, taskId=task_id)

    finished = False
    try:
        async for sdk_message in sdk_stream:
            responses = adapter.convert_message(sdk_message)
            for response in responses:
                # Skip duplicate start messages
                if isinstance(response, StreamStart):
                    continue
                if isinstance(response, StreamFinish):
                    finished = True
                yield response

    except Exception as e:
        logger.error(f"Error in SDK stream: {e}", exc_info=True)
        yield StreamError(
            errorText="An error occurred. Please try again.",
            code="stream_error",
        )
        yield StreamFinish()
        return

    # Ensure terminal StreamFinish if SDK stream ended without one
    if not finished:
        yield StreamFinish()
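
Not part of the diff: a standalone sketch of the text-block lifecycle the adapter above tracks. Text deltas reuse one block id until a tool call closes the block; any text after a tool call opens a fresh block with a new id. Class and event names here are illustrative, not the real stream protocol types.

# Standalone sketch of the start/delta/end state machine around tool calls.
import uuid


class TextBlockTracker:
    def __init__(self) -> None:
        self.block_id = str(uuid.uuid4())
        self.started = False
        self.ended = False

    def on_text(self, delta: str) -> list[tuple[str, str]]:
        events: list[tuple[str, str]] = []
        if not self.started or self.ended:
            if self.ended:
                self.block_id = str(uuid.uuid4())  # new block after a tool call
                self.ended = False
            events.append(("text-start", self.block_id))
            self.started = True
        events.append(("text-delta", delta))
        return events

    def on_tool_use(self) -> list[tuple[str, str]]:
        # Close the open text block before the tool call is surfaced
        if self.started and not self.ended:
            self.ended = True
            return [("text-end", self.block_id)]
        return []


if __name__ == "__main__":
    tracker = TextBlockTracker()
    stream = tracker.on_text("Searching") + tracker.on_tool_use() + tracker.on_text("Found 3 agents")
    for event in stream:
        print(event)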
@@ -0,0 +1,281 @@
"""Security hooks for Claude Agent SDK integration.

This module provides security hooks that validate tool calls before execution,
ensuring multi-user isolation and preventing unauthorized operations.
"""

import logging
import re
from typing import Any, cast

logger = logging.getLogger(__name__)

# Tools that are blocked entirely (CLI/system access)
BLOCKED_TOOLS = {
    "Bash",
    "bash",
    "shell",
    "exec",
    "terminal",
    "command",
    "Read",  # Block raw file read - use workspace tools instead
    "Write",  # Block raw file write - use workspace tools instead
    "Edit",  # Block raw file edit - use workspace tools instead
    "Glob",  # Block raw file glob - use workspace tools instead
    "Grep",  # Block raw file grep - use workspace tools instead
}

# Dangerous patterns in tool inputs
DANGEROUS_PATTERNS = [
    r"sudo",
    r"rm\s+-rf",
    r"dd\s+if=",
    r"/etc/passwd",
    r"/etc/shadow",
    r"chmod\s+777",
    r"curl\s+.*\|.*sh",
    r"wget\s+.*\|.*sh",
    r"eval\s*\(",
    r"exec\s*\(",
    r"__import__",
    r"os\.system",
    r"subprocess",
]


def _validate_tool_access(tool_name: str, tool_input: dict[str, Any]) -> dict[str, Any]:
    """Validate that a tool call is allowed.

    Returns:
        Empty dict to allow, or dict with hookSpecificOutput to deny
    """
    # Block forbidden tools
    if tool_name in BLOCKED_TOOLS:
        logger.warning(f"Blocked tool access attempt: {tool_name}")
        return {
            "hookSpecificOutput": {
                "hookEventName": "PreToolUse",
                "permissionDecision": "deny",
                "permissionDecisionReason": (
                    f"Tool '{tool_name}' is not available. "
                    "Use the CoPilot-specific tools instead."
                ),
            }
        }

    # Check for dangerous patterns in tool input
    input_str = str(tool_input)

    for pattern in DANGEROUS_PATTERNS:
        if re.search(pattern, input_str, re.IGNORECASE):
            logger.warning(
                f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
            )
            return {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "deny",
                    "permissionDecisionReason": "Input contains blocked pattern",
                }
            }

    return {}


def _validate_user_isolation(
    tool_name: str, tool_input: dict[str, Any], user_id: str | None
) -> dict[str, Any]:
    """Validate that tool calls respect user isolation."""
    # For workspace file tools, ensure path doesn't escape
    if "workspace" in tool_name.lower():
        path = tool_input.get("path", "") or tool_input.get("file_path", "")
        if path:
            # Check for path traversal
            if ".." in path or path.startswith("/"):
                logger.warning(
                    f"Blocked path traversal attempt: {path} by user {user_id}"
                )
                return {
                    "hookSpecificOutput": {
                        "hookEventName": "PreToolUse",
                        "permissionDecision": "deny",
                        "permissionDecisionReason": "Path traversal not allowed",
                    }
                }

    return {}


def create_security_hooks(user_id: str | None) -> dict[str, Any]:
    """Create the security hooks configuration for Claude Agent SDK.

    Includes security validation and observability hooks:
    - PreToolUse: Security validation before tool execution
    - PostToolUse: Log successful tool executions
    - PostToolUseFailure: Log and handle failed tool executions
    - PreCompact: Log context compaction events (SDK handles compaction automatically)

    Args:
        user_id: Current user ID for isolation validation

    Returns:
        Hooks configuration dict for ClaudeAgentOptions
    """
    try:
        from claude_agent_sdk import HookMatcher
        from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput

        async def pre_tool_use_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Combined pre-tool-use validation hook."""
            _ = context  # unused but required by signature
            tool_name = cast(str, input_data.get("tool_name", ""))
            tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))

            # Strip MCP prefix for consistent validation
            clean_name = tool_name.removeprefix("mcp__copilot__")

            # Validate basic tool access
            result = _validate_tool_access(clean_name, tool_input)
            if result:
                return cast(SyncHookJSONOutput, result)

            # Validate user isolation
            result = _validate_user_isolation(clean_name, tool_input, user_id)
            if result:
                return cast(SyncHookJSONOutput, result)

            logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
            return cast(SyncHookJSONOutput, {})

        async def post_tool_use_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Log successful tool executions for observability."""
            _ = context
            tool_name = cast(str, input_data.get("tool_name", ""))
            logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
            return cast(SyncHookJSONOutput, {})

        async def post_tool_failure_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Log failed tool executions for debugging."""
            _ = context
            tool_name = cast(str, input_data.get("tool_name", ""))
            error = input_data.get("error", "Unknown error")
            logger.warning(
                f"[SDK] Tool failed: {tool_name}, error={error}, "
                f"user={user_id}, tool_use_id={tool_use_id}"
            )
            return cast(SyncHookJSONOutput, {})

        async def pre_compact_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Log when SDK triggers context compaction.

            The SDK automatically compacts conversation history when it grows too large.
            This hook provides visibility into when compaction happens.
            """
            _ = context, tool_use_id
            trigger = input_data.get("trigger", "auto")
            logger.info(
                f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
            )
            return cast(SyncHookJSONOutput, {})

        return {
            "PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
            "PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
            "PostToolUseFailure": [
                HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
            ],
            "PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
        }
    except ImportError:
        # Fallback for when SDK isn't available - return empty hooks
        return {}


def create_strict_security_hooks(
    user_id: str | None,
    allowed_tools: list[str] | None = None,
) -> dict[str, Any]:
    """Create strict security hooks that only allow specific tools.

    Args:
        user_id: Current user ID
        allowed_tools: List of allowed tool names (defaults to CoPilot tools)

    Returns:
        Hooks configuration dict
    """
    try:
        from claude_agent_sdk import HookMatcher
        from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput

        from .tool_adapter import RAW_TOOL_NAMES

        tools_list = allowed_tools if allowed_tools is not None else RAW_TOOL_NAMES
        allowed_set = set(tools_list)

        async def strict_pre_tool_use(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Strict validation that only allows whitelisted tools."""
            _ = context  # unused but required by signature
            tool_name = cast(str, input_data.get("tool_name", ""))
            tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))

            # Remove MCP prefix if present
            clean_name = tool_name.removeprefix("mcp__copilot__")

            if clean_name not in allowed_set:
                logger.warning(f"Blocked non-whitelisted tool: {tool_name}")
                return cast(
                    SyncHookJSONOutput,
                    {
                        "hookSpecificOutput": {
                            "hookEventName": "PreToolUse",
                            "permissionDecision": "deny",
                            "permissionDecisionReason": (
                                f"Tool '{tool_name}' is not in the allowed list"
                            ),
                        }
                    },
                )

            # Run standard validations using clean_name for consistent checks
            result = _validate_tool_access(clean_name, tool_input)
            if result:
                return cast(SyncHookJSONOutput, result)

            result = _validate_user_isolation(clean_name, tool_input, user_id)
            if result:
                return cast(SyncHookJSONOutput, result)

            logger.debug(
                f"[SDK Audit] Tool call: tool={tool_name}, "
                f"user={user_id}, tool_use_id={tool_use_id}"
            )
            return cast(SyncHookJSONOutput, {})

        return {
            "PreToolUse": [
                HookMatcher(matcher="*", hooks=[strict_pre_tool_use]),
            ],
        }
    except ImportError:
        return {}
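
Not part of the diff: a standalone sketch of how a PreToolUse decision is shaped. An empty dict allows the call; a deny decision carries the reason back to the caller. The tool and pattern lists here are abbreviated stand-ins for the BLOCKED_TOOLS and DANGEROUS_PATTERNS constants above, and the example calls are hypothetical.

# Standalone sketch of the allow/deny decision envelope.
import re

BLOCKED = {"Bash", "Read", "Write"}
DANGEROUS = [r"rm\s+-rf", r"subprocess", r"\.\./"]


def check_call(tool_name: str, tool_input: dict) -> dict:
    if tool_name in BLOCKED:
        return {
            "hookSpecificOutput": {
                "hookEventName": "PreToolUse",
                "permissionDecision": "deny",
                "permissionDecisionReason": f"Tool '{tool_name}' is not available.",
            }
        }
    text = str(tool_input)
    for pattern in DANGEROUS:
        if re.search(pattern, text, re.IGNORECASE):
            return {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "deny",
                    "permissionDecisionReason": "Input contains blocked pattern",
                }
            }
    return {}  # empty result means the call is allowed


if __name__ == "__main__":
    print(check_call("run_block", {"graph_id": "abc"}))   # {} -> allowed
    print(check_call("Bash", {"command": "ls"}))           # deny: blocked tool
    print(check_call("run_block", {"cmd": "rm -rf /"}))    # deny: dangerous pattern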
@@ -0,0 +1,475 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""

import asyncio
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any

import openai

from backend.data.understanding import (
    format_understanding_for_prompt,
    get_business_understanding,
)
from backend.util.exceptions import NotFoundError

from ..config import ChatConfig
from ..model import (
    ChatMessage,
    ChatSession,
    get_chat_session,
    update_session_title,
    upsert_chat_session,
)
from ..response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamStart,
    StreamTextDelta,
    StreamToolInputAvailable,
    StreamToolOutputAvailable,
)
from ..tracking import track_user_message
from .anthropic_fallback import stream_with_anthropic
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
    COPILOT_TOOL_NAMES,
    create_copilot_mcp_server,
    set_execution_context,
)

logger = logging.getLogger(__name__)
config = ChatConfig()

# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()

DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.

Here is everything you know about the current user from previous interactions:

<users_information>
{users_information}
</users_information>

## YOUR CORE MANDATE

You are action-oriented. Your success is measured by:
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
- **Time Saved**: Focus on tangible efficiency gains
- **Quality Output**: Deliver results that meet or exceed expectations

## YOUR WORKFLOW

Adapt flexibly to the conversation context. Not every interaction requires all stages:

1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.

2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.

3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).

4. **Discover or Create Agents**:
   - **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
   - Search the marketplace with `find_agent` for pre-built automations
   - Find reusable components with `find_block`
   - Create custom solutions with `create_agent` if nothing suitable exists
   - Modify existing library agents with `edit_agent`

5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.

6. **Show Results**: Display outputs using `agent_output`.

## BEHAVIORAL GUIDELINES

**Be Concise:**
- Target 2-5 short lines maximum
- Make every word count—no repetition or filler
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
- Avoid jargon (blocks, slugs, cron) unless the user asks

**Be Proactive:**
- Suggest next steps before being asked
- Anticipate needs based on conversation context and user information
- Look for opportunities to expand scope when relevant
- Reveal capabilities through action, not explanation

**Use Tools Effectively:**
- Select the right tool for each task
- **Always check `find_library_agent` before searching the marketplace**
- Use `add_understanding` to capture valuable business context
- When tool calls fail, try alternative approaches

## CRITICAL REMINDER

You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""


async def _build_system_prompt(
    user_id: str | None, has_conversation_history: bool = False
) -> tuple[str, Any]:
    """Build the system prompt with user's business understanding context.

    Args:
        user_id: The user ID to fetch understanding for.
        has_conversation_history: Whether there's existing conversation history.
            If True, we don't tell the model to greet/introduce (since they're
            already in a conversation).
    """
    understanding = None
    if user_id:
        try:
            understanding = await get_business_understanding(user_id)
        except Exception as e:
            logger.warning(f"Failed to fetch business understanding: {e}")

    if understanding:
        context = format_understanding_for_prompt(understanding)
    elif has_conversation_history:
        # Don't tell model to greet if there's conversation history
        context = "No prior understanding saved yet. Continue the existing conversation naturally."
    else:
        context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"

    return DEFAULT_SYSTEM_PROMPT.replace("{users_information}", context), understanding


def _format_conversation_history(session: ChatSession) -> str:
    """Format conversation history as a prompt context.

    The SDK handles context compaction automatically, but we apply
    max_context_messages as a safety guard to limit initial prompt size.
    """
    if not session.messages:
        return ""

    # Get all messages except the last user message (which will be the prompt)
    messages = session.messages[:-1] if session.messages else []
    if not messages:
        return ""

    # Apply max_context_messages limit as a safety guard
    # (SDK handles compaction, but this prevents excessively large initial prompts)
    max_messages = config.max_context_messages
    if len(messages) > max_messages:
        messages = messages[-max_messages:]

    history_parts = ["<conversation_history>"]

    for msg in messages:
        if msg.role == "user":
            history_parts.append(f"User: {msg.content or ''}")
        elif msg.role == "assistant":
            # Pass full content - SDK handles compaction automatically
            history_parts.append(f"Assistant: {msg.content or ''}")
            if msg.tool_calls:
                for tc in msg.tool_calls:
                    func = tc.get("function", {})
                    history_parts.append(
                        f"  [Called tool: {func.get('name', 'unknown')}]"
                    )
        elif msg.role == "tool":
            # Truncate large tool results to avoid blowing context window
            tool_content = msg.content or ""
            if len(tool_content) > 500:
                tool_content = tool_content[:500] + "... (truncated)"
            history_parts.append(f"  [Tool result: {tool_content}]")

    history_parts.append("</conversation_history>")
    history_parts.append("")
    history_parts.append(
        "Continue this conversation. Respond to the user's latest message:"
    )
    history_parts.append("")

    return "\n".join(history_parts)


async def _generate_session_title(
    message: str,
    user_id: str | None = None,
    session_id: str | None = None,
) -> str | None:
    """Generate a concise title for a chat session."""
    from backend.util.settings import Settings

    settings = Settings()
    try:
        # Build extra_body for OpenRouter tracing
        extra_body: dict[str, Any] = {
            "posthogProperties": {"environment": settings.config.app_env.value},
        }
        if user_id:
            extra_body["user"] = user_id[:128]
            extra_body["posthogDistinctId"] = user_id
        if session_id:
            extra_body["session_id"] = session_id[:128]

        client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
        response = await client.chat.completions.create(
            model=config.title_model,
            messages=[
                {
                    "role": "system",
                    "content": "Generate a very short title (3-6 words) for a chat conversation based on the user's first message. Return ONLY the title, no quotes or punctuation.",
                },
                {"role": "user", "content": message[:500]},
            ],
            max_tokens=20,
            extra_body=extra_body,
        )
        title = response.choices[0].message.content
        if title:
            title = title.strip().strip("\"'")
            return title[:47] + "..." if len(title) > 50 else title
        return None
    except Exception as e:
        logger.warning(f"Failed to generate session title: {e}")
        return None


async def stream_chat_completion_sdk(
    session_id: str,
    message: str | None = None,
    tool_call_response: str | None = None,  # noqa: ARG001
    is_user_message: bool = True,
    user_id: str | None = None,
    retry_count: int = 0,  # noqa: ARG001
    session: ChatSession | None = None,
    context: dict[str, str] | None = None,  # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
    """Stream chat completion using Claude Agent SDK.

    Drop-in replacement for stream_chat_completion with improved reliability.
    """

    if session is None:
        session = await get_chat_session(session_id, user_id)

    if not session:
        raise NotFoundError(
            f"Session {session_id} not found. Please create a new session first."
        )

    if message:
        session.messages.append(
            ChatMessage(
                role="user" if is_user_message else "assistant", content=message
            )
        )
        if is_user_message:
            track_user_message(
                user_id=user_id, session_id=session_id, message_length=len(message)
            )

    session = await upsert_chat_session(session)

    # Generate title for new sessions (first user message)
    if is_user_message and not session.title:
        user_messages = [m for m in session.messages if m.role == "user"]
        if len(user_messages) == 1:
            first_message = user_messages[0].content or message or ""
            if first_message:
                task = asyncio.create_task(
                    _update_title_async(session_id, first_message, user_id)
                )
                # Store reference to prevent garbage collection
                _background_tasks.add(task)
                task.add_done_callback(_background_tasks.discard)

    # Check if there's conversation history (more than just the current message)
    has_history = len(session.messages) > 1
    system_prompt, _ = await _build_system_prompt(
        user_id, has_conversation_history=has_history
    )
    set_execution_context(user_id, session, None)

    message_id = str(uuid.uuid4())
    text_block_id = str(uuid.uuid4())
    task_id = str(uuid.uuid4())

    yield StreamStart(messageId=message_id, taskId=task_id)

    # Track whether the stream completed normally via ResultMessage
    stream_completed = False

    try:
        try:
            from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient

            # Create MCP server with CoPilot tools
            mcp_server = create_copilot_mcp_server()

            options = ClaudeAgentOptions(
                system_prompt=system_prompt,
                mcp_servers={"copilot": mcp_server},  # type: ignore[arg-type]
                allowed_tools=COPILOT_TOOL_NAMES,
                hooks=create_security_hooks(user_id),  # type: ignore[arg-type]
                continue_conversation=True,  # Enable conversation continuation
            )

            adapter = SDKResponseAdapter(message_id=message_id)
            adapter.set_task_id(task_id)

            async with ClaudeSDKClient(options=options) as client:
                # Build prompt with conversation history for context
                # The SDK doesn't support replaying full conversation history,
                # so we include it as context in the prompt
                current_message = message or ""
                if not current_message and session.messages:
                    last_user = [m for m in session.messages if m.role == "user"]
                    if last_user:
                        current_message = last_user[-1].content or ""

                # Include conversation history if there are prior messages
                if len(session.messages) > 1:
                    history_context = _format_conversation_history(session)
                    prompt = f"{history_context}{current_message}"
                else:
                    prompt = current_message

                # Guard against empty prompts
                if not prompt.strip():
                    yield StreamError(
                        errorText="Message cannot be empty.",
                        code="empty_prompt",
                    )
                    yield StreamFinish()
                    return

                await client.query(prompt, session_id=session_id)

                # Track assistant response to save to session
                # We may need multiple assistant messages if text comes after tool results
                assistant_response = ChatMessage(role="assistant", content="")
                accumulated_tool_calls: list[dict[str, Any]] = []
                has_appended_assistant = False
                has_tool_results = False  # Track if we've received tool results

                # Receive messages from the SDK
                async for sdk_msg in client.receive_messages():
                    for response in adapter.convert_message(sdk_msg):
                        if isinstance(response, StreamStart):
                            continue
                        yield response

                        # Accumulate text deltas into assistant response
                        if isinstance(response, StreamTextDelta):
                            delta = response.delta or ""
                            # After tool results, create new assistant message for post-tool text
                            if has_tool_results and has_appended_assistant:
                                assistant_response = ChatMessage(
                                    role="assistant", content=delta
                                )
                                accumulated_tool_calls = []  # Reset for new message
                                session.messages.append(assistant_response)
                                has_tool_results = False
                            else:
                                assistant_response.content = (
                                    assistant_response.content or ""
                                ) + delta
                                if not has_appended_assistant:
                                    session.messages.append(assistant_response)
                                    has_appended_assistant = True

                        # Track tool calls on the assistant message
                        elif isinstance(response, StreamToolInputAvailable):
                            accumulated_tool_calls.append(
                                {
                                    "id": response.toolCallId,
                                    "type": "function",
                                    "function": {
                                        "name": response.toolName,
                                        "arguments": json.dumps(response.input or {}),
                                    },
                                }
                            )
                            # Update assistant message with tool calls
                            assistant_response.tool_calls = accumulated_tool_calls
                            # Append assistant message if not already (tool-only response)
                            if not has_appended_assistant:
                                session.messages.append(assistant_response)
                                has_appended_assistant = True

                        elif isinstance(response, StreamToolOutputAvailable):
                            session.messages.append(
                                ChatMessage(
                                    role="tool",
                                    content=(
                                        response.output
                                        if isinstance(response.output, str)
                                        else str(response.output)
                                    ),
                                    tool_call_id=response.toolCallId,
                                )
                            )
                            has_tool_results = True

                        elif isinstance(response, StreamFinish):
                            stream_completed = True

                    # Break out of the message loop if we received finish signal
                    if stream_completed:
                        break

                # Ensure assistant response is saved even if no text deltas
                # (e.g., only tool calls were made)
                if (
                    assistant_response.content or assistant_response.tool_calls
                ) and not has_appended_assistant:
                    session.messages.append(assistant_response)

        except ImportError:
            logger.warning(
                "[SDK] claude-agent-sdk not available, using Anthropic fallback"
            )
            async for response in stream_with_anthropic(
                session, system_prompt, text_block_id
            ):
                if isinstance(response, StreamFinish):
                    stream_completed = True
                yield response

        # Save the session with accumulated messages
        await upsert_chat_session(session)
        logger.debug(
            f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
        )
        # Yield StreamFinish to signal completion to the caller (routes.py)
        # Only if one hasn't already been yielded by the stream
        if not stream_completed:
            yield StreamFinish()

    except Exception as e:
        logger.error(f"[SDK] Error: {e}", exc_info=True)
        # Save session even on error to preserve any partial response
        try:
            await upsert_chat_session(session)
        except Exception as save_err:
            logger.error(f"[SDK] Failed to save session on error: {save_err}")
        # Sanitize error message to avoid exposing internal details
        yield StreamError(
            errorText="An error occurred. Please try again.",
            code="sdk_error",
        )
        yield StreamFinish()


async def _update_title_async(
    session_id: str, message: str, user_id: str | None = None
) -> None:
    """Background task to update session title."""
    try:
        title = await _generate_session_title(
            message, user_id=user_id, session_id=session_id
        )
        if title:
            await update_session_title(session_id, title)
            logger.debug(f"[SDK] Generated title for {session_id}: {title}")
    except Exception as e:
        logger.warning(f"[SDK] Failed to update session title: {e}")
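
Not part of the diff: a standalone sketch of the history-window guard used when building the SDK prompt - keep only the last N prior turns and truncate bulky tool results. The limits below are assumed values standing in for ChatConfig.max_context_messages and the 500-character cutoff above.

# Standalone sketch of the prompt-history guard.
MAX_CONTEXT_MESSAGES = 50  # assumed limit; the real value comes from ChatConfig
TOOL_RESULT_LIMIT = 500


def format_history(messages: list[dict[str, str]]) -> str:
    prior = messages[:-1]  # the last user message becomes the prompt itself
    prior = prior[-MAX_CONTEXT_MESSAGES:]
    parts = ["<conversation_history>"]
    for msg in prior:
        if msg["role"] == "user":
            parts.append(f"User: {msg['content']}")
        elif msg["role"] == "assistant":
            parts.append(f"Assistant: {msg['content']}")
        elif msg["role"] == "tool":
            content = msg["content"]
            if len(content) > TOOL_RESULT_LIMIT:
                content = content[:TOOL_RESULT_LIMIT] + "... (truncated)"
            parts.append(f"  [Tool result: {content}]")
    parts.append("</conversation_history>\n")
    parts.append("Continue this conversation. Respond to the user's latest message:\n")
    return "\n".join(parts)


if __name__ == "__main__":
    print(format_history([
        {"role": "user", "content": "Find an invoicing agent"},
        {"role": "tool", "content": "x" * 600},
        {"role": "user", "content": "Run it every Monday"},
    ]))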
@@ -0,0 +1,217 @@
|
|||||||
|
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
|
||||||
|
|
||||||
|
This module provides the adapter layer that converts existing BaseTool implementations
|
||||||
|
into in-process MCP tools that can be used with the Claude Agent SDK.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.chat.tools import TOOL_REGISTRY
|
||||||
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Context variables to pass user/session info to tool execution
|
||||||
|
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||||
|
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||||
|
"current_session", default=None
|
||||||
|
)
|
||||||
|
_current_tool_call_id: ContextVar[str | None] = ContextVar(
|
||||||
|
"current_tool_call_id", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_execution_context(
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
tool_call_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set the execution context for tool calls.
|
||||||
|
|
||||||
|
This must be called before streaming begins to ensure tools have access
|
||||||
|
to user_id and session information.
|
||||||
|
"""
|
||||||
|
_current_user_id.set(user_id)
|
||||||
|
_current_session.set(session)
|
||||||
|
_current_tool_call_id.set(tool_call_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
|
||||||
|
"""Get the current execution context."""
|
||||||
|
return (
|
||||||
|
_current_user_id.get(),
|
||||||
|
_current_session.get(),
|
||||||
|
_current_tool_call_id.get(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_tool_handler(base_tool: BaseTool):
    """Create an async handler function for a BaseTool.

    This wraps the existing BaseTool._execute method to be compatible
    with the Claude Agent SDK MCP tool format.
    """

    async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
        """Execute the wrapped tool and return MCP-formatted response."""
        user_id, session, tool_call_id = get_execution_context()

        if session is None:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": json.dumps(
                            {
                                "error": "No session context available",
                                "type": "error",
                            }
                        ),
                    }
                ],
                "isError": True,
            }

        try:
            # Call the existing tool's execute method
            # Generate unique tool_call_id per invocation for proper correlation
            effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
            result = await base_tool.execute(
                user_id=user_id,
                session=session,
                tool_call_id=effective_id,
                **args,
            )

            # The result is a StreamToolOutputAvailable, extract the output
            return {
                "content": [
                    {
                        "type": "text",
                        "text": (
                            result.output
                            if isinstance(result.output, str)
                            else json.dumps(result.output)
                        ),
                    }
                ],
                "isError": not result.success,
            }

        except Exception as e:
            logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
            return {
                "content": [
                    {
                        "type": "text",
                        "text": json.dumps(
                            {
                                "error": str(e),
                                "type": "error",
                                "message": f"Failed to execute {base_tool.name}",
                            }
                        ),
                    }
                ],
                "isError": True,
            }

    return tool_handler


def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
    """Build a JSON Schema input schema for a tool."""
    return {
        "type": "object",
        "properties": base_tool.parameters.get("properties", {}),
        "required": base_tool.parameters.get("required", []),
    }


def get_tool_definitions() -> list[dict[str, Any]]:
    """Get all tool definitions in MCP format.

    Returns a list of tool definitions that can be used with
    create_sdk_mcp_server or as raw tool definitions.
    """
    tool_definitions = []

    for tool_name, base_tool in TOOL_REGISTRY.items():
        tool_def = {
            "name": tool_name,
            "description": base_tool.description,
            "inputSchema": _build_input_schema(base_tool),
        }
        tool_definitions.append(tool_def)

    return tool_definitions


def get_tool_handlers() -> dict[str, Any]:
    """Get all tool handlers mapped by name.

    Returns a dictionary mapping tool names to their handler functions.
    """
    handlers = {}

    for tool_name, base_tool in TOOL_REGISTRY.items():
        handlers[tool_name] = create_tool_handler(base_tool)

    return handlers


# Create the MCP server configuration
def create_copilot_mcp_server():
    """Create an in-process MCP server configuration for CoPilot tools.

    This can be passed to ClaudeAgentOptions.mcp_servers.

    Note: The actual SDK MCP server creation depends on the claude-agent-sdk
    package being available. This function returns the configuration that
    can be used with the SDK.
    """
    try:
        from claude_agent_sdk import create_sdk_mcp_server, tool

        # Create decorated tool functions
        sdk_tools = []

        for tool_name, base_tool in TOOL_REGISTRY.items():
            # Get the handler
            handler = create_tool_handler(base_tool)

            # Create the decorated tool
            # The @tool decorator expects (name, description, schema)
            # Pass full JSON schema with type, properties, and required
            decorated = tool(
                tool_name,
                base_tool.description,
                _build_input_schema(base_tool),
            )(handler)

            sdk_tools.append(decorated)

        # Create the MCP server
        server = create_sdk_mcp_server(
            name="copilot",
            version="1.0.0",
            tools=sdk_tools,
        )

        return server

    except ImportError:
        # Let ImportError propagate so service.py handles the fallback
        raise


# List of tool names for allowed_tools configuration
COPILOT_TOOL_NAMES = [f"mcp__copilot__{name}" for name in TOOL_REGISTRY.keys()]

# Also export the raw tool names for flexibility
RAW_TOOL_NAMES = list(TOOL_REGISTRY.keys())
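For orientation, a minimal wiring sketch (not part of this diff) of how the pieces above could be handed to the agent SDK. It assumes ClaudeAgentOptions is importable from claude_agent_sdk, that mcp_servers accepts a mapping keyed by server name, and that allowed_tools takes the prefixed names built above; all three are assumptions inferred from the docstrings, not confirmed API.

# Hypothetical usage sketch; signatures assumed, not taken from this diff.
def build_copilot_options():
    from claude_agent_sdk import ClaudeAgentOptions  # assumed import location

    server = create_copilot_mcp_server()  # defined above
    return ClaudeAgentOptions(
        mcp_servers={"copilot": server},   # assumption: dict keyed by server name
        allowed_tools=COPILOT_TOOL_NAMES,  # "mcp__copilot__<tool>" names from above
    )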
@@ -371,21 +371,45 @@ async def stream_chat_completion(
        ValueError: If max_context_messages is exceeded
    """
+    completion_start = time.monotonic()
+
+    # Build log metadata for structured logging
+    log_meta = {"component": "ChatService", "session_id": session_id}
+    if user_id:
+        log_meta["user_id"] = user_id
+
    logger.info(
-        f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
+        f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
+        f"message_len={len(message) if message else 0}, is_user={is_user_message}",
+        extra={
+            "json_fields": {
+                **log_meta,
+                "message_len": len(message) if message else 0,
+                "is_user_message": is_user_message,
+            }
+        },
    )

    # Only fetch from Redis if session not provided (initial call)
    if session is None:
+        fetch_start = time.monotonic()
        session = await get_chat_session(session_id, user_id)
+        fetch_time = (time.monotonic() - fetch_start) * 1000
        logger.info(
-            f"Fetched session from Redis: {session.session_id if session else 'None'}, "
-            f"message_count={len(session.messages) if session else 0}"
+            f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
+            f"n_messages={len(session.messages) if session else 0}",
+            extra={
+                "json_fields": {
+                    **log_meta,
+                    "duration_ms": fetch_time,
+                    "n_messages": len(session.messages) if session else 0,
+                }
+            },
        )
    else:
        logger.info(
-            f"Using provided session object: {session.session_id}, "
-            f"message_count={len(session.messages)}"
+            f"[TIMING] Using provided session, messages={len(session.messages)}",
+            extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
        )

    if not session:
@@ -406,17 +430,25 @@ async def stream_chat_completion(

    # Track user message in PostHog
    if is_user_message:
+        posthog_start = time.monotonic()
        track_user_message(
            user_id=user_id,
            session_id=session_id,
            message_length=len(message),
        )
+        posthog_time = (time.monotonic() - posthog_start) * 1000
+        logger.info(
+            f"[TIMING] track_user_message took {posthog_time:.1f}ms",
+            extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
+        )

-    logger.info(
-        f"Upserting session: {session.session_id} with user id {session.user_id}, "
-        f"message_count={len(session.messages)}"
-    )
+    upsert_start = time.monotonic()
    session = await upsert_chat_session(session)
+    upsert_time = (time.monotonic() - upsert_start) * 1000
+    logger.info(
+        f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
+        extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
+    )
    assert session, "Session not found"

    # Generate title for new sessions on first user message (non-blocking)
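The diff repeats the same measure-then-log pattern for each step: capture a start time, compute elapsed milliseconds, and emit an "[TIMING]" message with a json_fields payload. A small helper could factor that out; the sketch below is illustrative only (stdlib-only, not part of this change, and the name log_timing is hypothetical).

# Illustrative helper capturing the repeated timing-log pattern above.
import logging
import time
from contextlib import contextmanager

@contextmanager
def log_timing(logger: logging.Logger, label: str, **json_fields):
    """On exit, log '[TIMING] <label> took Xms' with structured json_fields."""
    start = time.monotonic()
    try:
        yield
    finally:
        duration_ms = (time.monotonic() - start) * 1000
        logger.info(
            f"[TIMING] {label} took {duration_ms:.1f}ms",
            extra={"json_fields": {**json_fields, "duration_ms": duration_ms}},
        )

# Usage sketch:
# with log_timing(logger, "upsert_chat_session", **log_meta):
#     session = await upsert_chat_session(session)

The commit instead inlines each measurement, which keeps per-step variable names (fetch_time, upsert_time, ...) available for the log message text.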
@@ -454,7 +486,13 @@ async def stream_chat_completion(
        asyncio.create_task(_update_title())

    # Build system prompt with business understanding
+    prompt_start = time.monotonic()
    system_prompt, understanding = await _build_system_prompt(user_id)
+    prompt_time = (time.monotonic() - prompt_start) * 1000
+    logger.info(
+        f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
+        extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
+    )

    # Initialize variables for streaming
    assistant_response = ChatMessage(
@@ -483,9 +521,18 @@ async def stream_chat_completion(
    text_block_id = str(uuid_module.uuid4())

    # Yield message start
+    setup_time = (time.monotonic() - completion_start) * 1000
+    logger.info(
+        f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
+        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
+    )
    yield StreamStart(messageId=message_id)

    try:
+        logger.info(
+            "[TIMING] Calling _stream_chat_chunks",
+            extra={"json_fields": log_meta},
+        )
        async for chunk in _stream_chat_chunks(
            session=session,
            tools=tools,
@@ -893,9 +940,21 @@ async def _stream_chat_chunks(
        SSE formatted JSON response objects

    """
+    import time as time_module
+
+    stream_chunks_start = time_module.perf_counter()
    model = config.model

-    logger.info("Starting pure chat stream")
+    # Build log metadata for structured logging
+    log_meta = {"component": "ChatService", "session_id": session.session_id}
+    if session.user_id:
+        log_meta["user_id"] = session.user_id
+
+    logger.info(
+        f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
+        f"user={session.user_id}, n_messages={len(session.messages)}",
+        extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
+    )
+
    messages = session.to_openai_messages()
    if system_prompt:
@@ -906,12 +965,18 @@ async def _stream_chat_chunks(
        messages = [system_message] + messages

    # Apply context window management
+    context_start = time_module.perf_counter()
    context_result = await _manage_context_window(
        messages=messages,
        model=model,
        api_key=config.api_key,
        base_url=config.base_url,
    )
+    context_time = (time_module.perf_counter() - context_start) * 1000
+    logger.info(
+        f"[TIMING] _manage_context_window took {context_time:.1f}ms",
+        extra={"json_fields": {**log_meta, "duration_ms": context_time}},
+    )

    if context_result.error:
        if "System prompt dropped" in context_result.error:
@@ -946,9 +1011,19 @@ async def _stream_chat_chunks(

    while retry_count <= MAX_RETRIES:
        try:
+            elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
+            retry_info = (
+                f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
+            )
            logger.info(
-                f"Creating OpenAI chat completion stream..."
-                f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
+                f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "elapsed_ms": elapsed,
+                        "retry_count": retry_count,
+                    }
+                },
            )

            # Build extra_body for OpenRouter tracing and PostHog analytics
@@ -965,6 +1040,7 @@ async def _stream_chat_chunks(
                :128
            ]  # OpenRouter limit

+            api_call_start = time_module.perf_counter()
            stream = await client.chat.completions.create(
                model=model,
                messages=cast(list[ChatCompletionMessageParam], messages),
@@ -974,6 +1050,11 @@ async def _stream_chat_chunks(
                stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
                extra_body=extra_body,
            )
+            api_init_time = (time_module.perf_counter() - api_call_start) * 1000
+            logger.info(
+                f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
+                extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
+            )

            # Variables to accumulate tool calls
            tool_calls: list[dict[str, Any]] = []
@@ -984,10 +1065,13 @@ async def _stream_chat_chunks(

            # Track if we've started the text block
            text_started = False
+            first_content_chunk = True
+            chunk_count = 0

            # Process the stream
            chunk: ChatCompletionChunk
            async for chunk in stream:
+                chunk_count += 1
                if chunk.usage:
                    yield StreamUsage(
                        promptTokens=chunk.usage.prompt_tokens,
@@ -1010,6 +1094,23 @@ async def _stream_chat_chunks(
                    if not text_started and text_block_id:
                        yield StreamTextStart(id=text_block_id)
                        text_started = True
+                    # Log timing for first content chunk
+                    if first_content_chunk:
+                        first_content_chunk = False
+                        ttfc = (
+                            time_module.perf_counter() - api_call_start
+                        ) * 1000
+                        logger.info(
+                            f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
+                            f"(since API call), n_chunks={chunk_count}",
+                            extra={
+                                "json_fields": {
+                                    **log_meta,
+                                    "time_to_first_chunk_ms": ttfc,
+                                    "n_chunks": chunk_count,
+                                }
+                            },
+                        )
                    # Stream the text delta
                    text_response = StreamTextDelta(
                        id=text_block_id or "",
@@ -1066,7 +1167,21 @@ async def _stream_chat_chunks(
                            toolName=tool_calls[idx]["function"]["name"],
                        )
                        emitted_start_for_idx.add(idx)
-            logger.info(f"Stream complete. Finish reason: {finish_reason}")
+            stream_duration = time_module.perf_counter() - api_call_start
+            logger.info(
+                f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
+                f"duration={stream_duration:.2f}s, "
+                f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "stream_duration_ms": stream_duration * 1000,
+                        "finish_reason": finish_reason,
+                        "n_chunks": chunk_count,
+                        "n_tool_calls": len(tool_calls),
+                    }
+                },
+            )

            # Yield all accumulated tool calls after the stream is complete
            # This ensures all tool call arguments have been fully received
@@ -1086,6 +1201,12 @@ async def _stream_chat_chunks(
                    # Re-raise to trigger retry logic in the parent function
                    raise

+            total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
+            logger.info(
+                f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
+                f"session={session.session_id}, user={session.user_id}",
+                extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
+            )
            yield StreamFinish()
            return
        except Exception as e:
@@ -104,6 +104,24 @@ async def create_task(
    Returns:
        The created ActiveTask instance (metadata only)
    """
+    import time
+
+    start_time = time.perf_counter()
+
+    # Build log metadata for structured logging
+    log_meta = {
+        "component": "StreamRegistry",
+        "task_id": task_id,
+        "session_id": session_id,
+    }
+    if user_id:
+        log_meta["user_id"] = user_id
+
+    logger.info(
+        f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
+        extra={"json_fields": log_meta},
+    )
+
    task = ActiveTask(
        task_id=task_id,
        session_id=session_id,
@@ -114,10 +132,18 @@ async def create_task(
    )

    # Store metadata in Redis
+    redis_start = time.perf_counter()
    redis = await get_redis_async()
+    redis_time = (time.perf_counter() - redis_start) * 1000
+    logger.info(
+        f"[TIMING] get_redis_async took {redis_time:.1f}ms",
+        extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
+    )

    meta_key = _get_task_meta_key(task_id)
    op_key = _get_operation_mapping_key(operation_id)

+    hset_start = time.perf_counter()
    await redis.hset(  # type: ignore[misc]
        meta_key,
        mapping={
@@ -131,12 +157,22 @@ async def create_task(
            "created_at": task.created_at.isoformat(),
        },
    )
+    hset_time = (time.perf_counter() - hset_start) * 1000
+    logger.info(
+        f"[TIMING] redis.hset took {hset_time:.1f}ms",
+        extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
+    )

    await redis.expire(meta_key, config.stream_ttl)

    # Create operation_id -> task_id mapping for webhook lookups
    await redis.set(op_key, task_id, ex=config.stream_ttl)

-    logger.debug(f"Created task {task_id} for session {session_id}")
+    total_time = (time.perf_counter() - start_time) * 1000
+    logger.info(
+        f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
+        extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
+    )
+
    return task

@@ -156,26 +192,60 @@ async def publish_chunk(
    Returns:
        The Redis Stream message ID
    """
+    import time
+
+    start_time = time.perf_counter()
+    chunk_type = type(chunk).__name__
    chunk_json = chunk.model_dump_json()
    message_id = "0-0"
+
+    # Build log metadata
+    log_meta = {
+        "component": "StreamRegistry",
+        "task_id": task_id,
+        "chunk_type": chunk_type,
+    }
+
    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)

        # Write to Redis Stream for persistence and real-time delivery
+        xadd_start = time.perf_counter()
        raw_id = await redis.xadd(
            stream_key,
            {"data": chunk_json},
            maxlen=config.stream_max_length,
        )
+        xadd_time = (time.perf_counter() - xadd_start) * 1000
        message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()

        # Set TTL on stream to match task metadata TTL
        await redis.expire(stream_key, config.stream_ttl)

+        total_time = (time.perf_counter() - start_time) * 1000
+        # Only log timing for significant chunks or slow operations
+        if (
+            chunk_type
+            in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
+            or total_time > 50
+        ):
+            logger.info(
+                f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "total_time_ms": total_time,
+                        "xadd_time_ms": xadd_time,
+                        "message_id": message_id,
+                    }
+                },
+            )
    except Exception as e:
+        elapsed = (time.perf_counter() - start_time) * 1000
        logger.error(
-            f"Failed to publish chunk for task {task_id}: {e}",
+            f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
+            extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
            exc_info=True,
        )

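The write path above appends each chunk to a capped Redis Stream and refreshes its TTL so that stream lifetime tracks the task metadata. A self-contained sketch of that idea follows, using redis.asyncio; the key format, maxlen, and TTL values here are illustrative placeholders, not the service's real configuration.

# Standalone sketch of a capped, TTL-bounded stream append (assumptions noted above).
import redis.asyncio as aioredis

async def append_chunk(task_id: str, payload: str) -> str:
    r = aioredis.Redis(decode_responses=True)
    stream_key = f"chat:stream:{task_id}"  # hypothetical key format
    # Cap the stream so a runaway producer cannot grow it without bound.
    message_id = await r.xadd(stream_key, {"data": payload}, maxlen=1000)
    # Keep the stream roughly as long as the task metadata.
    await r.expire(stream_key, 3600)
    return message_id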
@@ -200,24 +270,61 @@ async def subscribe_to_task(
        An asyncio Queue that will receive stream chunks, or None if task not found
        or user doesn't have access
    """
+    import time
+
+    start_time = time.perf_counter()
+
+    # Build log metadata
+    log_meta = {"component": "StreamRegistry", "task_id": task_id}
+    if user_id:
+        log_meta["user_id"] = user_id
+
+    logger.info(
+        f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
+        extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
+    )
+
+    redis_start = time.perf_counter()
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]
+    hgetall_time = (time.perf_counter() - redis_start) * 1000
+    logger.info(
+        f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
+        extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
+    )

    if not meta:
-        logger.debug(f"Task {task_id} not found in Redis")
+        elapsed = (time.perf_counter() - start_time) * 1000
+        logger.info(
+            f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
+            extra={
+                "json_fields": {
+                    **log_meta,
+                    "elapsed_ms": elapsed,
+                    "reason": "task_not_found",
+                }
+            },
+        )
        return None

    # Note: Redis client uses decode_responses=True, so keys are strings
    task_status = meta.get("status", "")
    task_user_id = meta.get("user_id", "") or None
+    log_meta["session_id"] = meta.get("session_id", "")

    # Validate ownership - if task has an owner, requester must match
    if task_user_id:
        if user_id != task_user_id:
            logger.warning(
-                f"User {user_id} denied access to task {task_id} "
-                f"owned by {task_user_id}"
+                f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "task_owner": task_user_id,
+                        "reason": "access_denied",
+                    }
+                },
            )
            return None

@@ -225,7 +332,19 @@ async def subscribe_to_task(
    stream_key = _get_task_stream_key(task_id)

    # Step 1: Replay messages from Redis Stream
+    xread_start = time.perf_counter()
    messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
+    xread_time = (time.perf_counter() - xread_start) * 1000
+    logger.info(
+        f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
+        extra={
+            "json_fields": {
+                **log_meta,
+                "duration_ms": xread_time,
+                "task_status": task_status,
+            }
+        },
+    )

    replayed_count = 0
    replay_last_id = last_message_id
@@ -244,19 +363,48 @@ async def subscribe_to_task(
        except Exception as e:
            logger.warning(f"Failed to replay message: {e}")

-    logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
+    logger.info(
+        f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
+        extra={
+            "json_fields": {
+                **log_meta,
+                "n_messages_replayed": replayed_count,
+                "replay_last_id": replay_last_id,
+            }
+        },
+    )

    # Step 2: If task is still running, start stream listener for live updates
    if task_status == "running":
+        logger.info(
+            "[TIMING] Task still running, starting _stream_listener",
+            extra={"json_fields": {**log_meta, "task_status": task_status}},
+        )
        listener_task = asyncio.create_task(
-            _stream_listener(task_id, subscriber_queue, replay_last_id)
+            _stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
        )
        # Track listener task for cleanup on unsubscribe
        _listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
    else:
        # Task is completed/failed - add finish marker
+        logger.info(
+            f"[TIMING] Task already {task_status}, adding StreamFinish",
+            extra={"json_fields": {**log_meta, "task_status": task_status}},
+        )
        await subscriber_queue.put(StreamFinish())

+    total_time = (time.perf_counter() - start_time) * 1000
+    logger.info(
+        f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
+        f"n_messages_replayed={replayed_count}",
+        extra={
+            "json_fields": {
+                **log_meta,
+                "total_time_ms": total_time,
+                "n_messages_replayed": replayed_count,
+            }
+        },
+    )
    return subscriber_queue

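The subscribe path above follows a two-step pattern: replay whatever is already in the Redis Stream from the caller's last-seen ID, then hand off to a listener that tails the stream with blocking XREAD calls. A condensed, stand-alone sketch of that pattern is below; it assumes redis.asyncio and an illustrative key format, and it yields raw payload strings rather than the service's typed chunk objects.

# Minimal replay-then-tail sketch (assumptions noted above; not the service code).
import redis.asyncio as aioredis

async def replay_then_tail(task_id: str, last_id: str = "0-0"):
    r = aioredis.Redis(decode_responses=True)
    stream_key = f"chat:stream:{task_id}"  # hypothetical key format
    # Step 1: drain the backlog from the last delivered ID.
    backlog = await r.xread({stream_key: last_id}, count=1000)
    for _, entries in backlog:
        for message_id, fields in entries:
            last_id = message_id
            yield fields["data"]
    # Step 2: tail for live messages, blocking up to 30s per call.
    while True:
        batch = await r.xread({stream_key: last_id}, block=30000, count=100)
        for _, entries in batch:
            for message_id, fields in entries:
                last_id = message_id
                yield fields["data"]

Because the same last_id drives both steps, a subscriber that reconnects can resume from its last delivered message without gaps or duplicates.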
@@ -264,6 +412,7 @@ async def _stream_listener(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
    last_replayed_id: str,
+    log_meta: dict | None = None,
) -> None:
    """Listen to Redis Stream for new messages using blocking XREAD.

@@ -274,10 +423,27 @@ async def _stream_listener(
        task_id: Task ID to listen for
        subscriber_queue: Queue to deliver messages to
        last_replayed_id: Last message ID from replay (continue from here)
+        log_meta: Structured logging metadata
    """
+    import time
+
+    start_time = time.perf_counter()
+
+    # Use provided log_meta or build minimal one
+    if log_meta is None:
+        log_meta = {"component": "StreamRegistry", "task_id": task_id}
+
+    logger.info(
+        f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
+        extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
+    )
+
    queue_id = id(subscriber_queue)
    # Track the last successfully delivered message ID for recovery hints
    last_delivered_id = last_replayed_id
+    messages_delivered = 0
+    first_message_time = None
+    xread_count = 0
+
    try:
        redis = await get_redis_async()
@@ -287,9 +453,39 @@ async def _stream_listener(
        while True:
            # Block for up to 30 seconds waiting for new messages
            # This allows periodic checking if task is still running
+            xread_start = time.perf_counter()
+            xread_count += 1
            messages = await redis.xread(
                {stream_key: current_id}, block=30000, count=100
            )
+            xread_time = (time.perf_counter() - xread_start) * 1000
+
+            if messages:
+                msg_count = sum(len(msgs) for _, msgs in messages)
+                logger.info(
+                    f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
+                    extra={
+                        "json_fields": {
+                            **log_meta,
+                            "xread_count": xread_count,
+                            "n_messages": msg_count,
+                            "duration_ms": xread_time,
+                        }
+                    },
+                )
+            elif xread_time > 1000:
+                # Only log timeouts (30s blocking)
+                logger.info(
+                    f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
+                    extra={
+                        "json_fields": {
+                            **log_meta,
+                            "xread_count": xread_count,
+                            "duration_ms": xread_time,
+                            "reason": "timeout",
+                        }
+                    },
+                )
+
            if not messages:
                # Timeout - check if task is still running
@@ -326,10 +522,30 @@ async def _stream_listener(
                        )
                        # Update last delivered ID on successful delivery
                        last_delivered_id = current_id
+                        messages_delivered += 1
+                        if first_message_time is None:
+                            first_message_time = time.perf_counter()
+                            elapsed = (first_message_time - start_time) * 1000
+                            logger.info(
+                                f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
+                                extra={
+                                    "json_fields": {
+                                        **log_meta,
+                                        "elapsed_ms": elapsed,
+                                        "chunk_type": type(chunk).__name__,
+                                    }
+                                },
+                            )
                    except asyncio.TimeoutError:
                        logger.warning(
-                            f"Subscriber queue full for task {task_id}, "
-                            f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
+                            f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
+                            extra={
+                                "json_fields": {
+                                    **log_meta,
+                                    "timeout_s": QUEUE_PUT_TIMEOUT,
+                                    "reason": "queue_full",
+                                }
+                            },
                        )
                        # Send overflow error with recovery info
                        try:
@@ -351,15 +567,44 @@ async def _stream_listener(

                    # Stop listening on finish
                    if isinstance(chunk, StreamFinish):
+                        total_time = (time.perf_counter() - start_time) * 1000
+                        logger.info(
+                            f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
+                            extra={
+                                "json_fields": {
+                                    **log_meta,
+                                    "total_time_ms": total_time,
+                                    "messages_delivered": messages_delivered,
+                                }
+                            },
+                        )
                        return
                except Exception as e:
-                    logger.warning(f"Error processing stream message: {e}")
+                    logger.warning(
+                        f"Error processing stream message: {e}",
+                        extra={"json_fields": {**log_meta, "error": str(e)}},
+                    )

    except asyncio.CancelledError:
-        logger.debug(f"Stream listener cancelled for task {task_id}")
+        elapsed = (time.perf_counter() - start_time) * 1000
+        logger.info(
+            f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
+            extra={
+                "json_fields": {
+                    **log_meta,
+                    "elapsed_ms": elapsed,
+                    "messages_delivered": messages_delivered,
+                    "reason": "cancelled",
+                }
+            },
+        )
        raise  # Re-raise to propagate cancellation
    except Exception as e:
-        logger.error(f"Stream listener error for task {task_id}: {e}")
+        elapsed = (time.perf_counter() - start_time) * 1000
+        logger.error(
+            f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
+            extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
+        )
        # On error, send finish to unblock subscriber
        try:
            await asyncio.wait_for(
@@ -368,10 +613,24 @@ async def _stream_listener(
            )
        except (asyncio.TimeoutError, asyncio.QueueFull):
            logger.warning(
-                f"Could not deliver finish event for task {task_id} after error"
+                "Could not deliver finish event after error",
+                extra={"json_fields": log_meta},
            )
    finally:
        # Clean up listener task mapping on exit
+        total_time = (time.perf_counter() - start_time) * 1000
+        logger.info(
+            f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
+            f"delivered={messages_delivered}, xread_count={xread_count}",
+            extra={
+                "json_fields": {
+                    **log_meta,
+                    "total_time_ms": total_time,
+                    "messages_delivered": messages_delivered,
+                    "xread_count": xread_count,
+                }
+            },
+        )
        _listener_tasks.pop(queue_id, None)

@@ -555,6 +814,10 @@ async def get_active_task_for_session(
        if task_user_id and user_id != task_user_id:
            continue

+        logger.info(
+            f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
+        )
+
        # Get the last message ID from Redis Stream
        stream_key = _get_task_stream_key(task_id)
        last_id = "0-0"
|||||||
@@ -13,10 +13,32 @@ from backend.api.features.chat.tools.models import (
|
|||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||||
from backend.data.block import get_block
|
from backend.data.block import BlockType, get_block
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TARGET_RESULTS = 10
|
||||||
|
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
|
||||||
|
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
|
||||||
|
_OVERFETCH_PAGE_SIZE = 40
|
||||||
|
|
||||||
|
# Block types that only work within graphs and cannot run standalone in CoPilot.
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||||
|
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
|
||||||
|
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
|
||||||
|
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
|
||||||
|
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
|
||||||
|
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
||||||
|
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
||||||
|
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
||||||
|
}
|
||||||
|
|
||||||
|
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS = {
|
||||||
|
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
|
||||||
|
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class FindBlockTool(BaseTool):
|
class FindBlockTool(BaseTool):
|
||||||
"""Tool for searching available blocks."""
|
"""Tool for searching available blocks."""
|
||||||
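Both FindBlockTool (below) and RunBlockTool apply the same exclusion rule inline: a block is CoPilot-runnable only if it exists, is enabled, and is not excluded by type or by ID. A hypothetical helper expressing that predicate is sketched here for clarity; it is not part of the diff, and it only assumes the block attributes (disabled, block_type, id) that the changed code already uses.

# Hypothetical predicate capturing the exclusion rule used inline in this change.
def is_copilot_runnable(block) -> bool:
    """Return True if a block can be executed directly from CoPilot."""
    if block is None or block.disabled:
        return False
    if block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES:
        return False
    return block.id not in COPILOT_EXCLUDED_BLOCK_IDS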
@@ -88,7 +110,7 @@ class FindBlockTool(BaseTool):
                query=query,
                content_types=[ContentType.BLOCK],
                page=1,
-                page_size=10,
+                page_size=_OVERFETCH_PAGE_SIZE,
            )

            if not results:
@@ -108,60 +130,90 @@ class FindBlockTool(BaseTool):
                block = get_block(block_id)

                # Skip disabled blocks
-                if block and not block.disabled:
+                if not block or block.disabled:
+                    continue
+
+                # Skip blocks excluded from CoPilot (graph-only blocks)
+                if (
+                    block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
+                    or block.id in COPILOT_EXCLUDED_BLOCK_IDS
+                ):
+                    continue
+
                # Get input/output schemas
                input_schema = {}
                output_schema = {}
                try:
                    input_schema = block.input_schema.jsonschema()
-                except Exception:
-                    pass
+                except Exception as e:
+                    logger.debug(
+                        "Failed to generate input schema for block %s: %s",
+                        block_id,
+                        e,
+                    )
                try:
                    output_schema = block.output_schema.jsonschema()
-                except Exception:
-                    pass
+                except Exception as e:
+                    logger.debug(
+                        "Failed to generate output schema for block %s: %s",
+                        block_id,
+                        e,
+                    )

                # Get categories from block instance
                categories = []
                if hasattr(block, "categories") and block.categories:
                    categories = [cat.value for cat in block.categories]

                # Extract required inputs for easier use
                required_inputs: list[BlockInputFieldInfo] = []
                if input_schema:
                    properties = input_schema.get("properties", {})
                    required_fields = set(input_schema.get("required", []))
                    # Get credential field names to exclude from required inputs
                    credentials_fields = set(
                        block.input_schema.get_credentials_fields().keys()
                    )

                    for field_name, field_schema in properties.items():
                        # Skip credential fields - they're handled separately
                        if field_name in credentials_fields:
                            continue

                        required_inputs.append(
                            BlockInputFieldInfo(
                                name=field_name,
                                type=field_schema.get("type", "string"),
                                description=field_schema.get("description", ""),
                                required=field_name in required_fields,
                                default=field_schema.get("default"),
                            )
                        )

                blocks.append(
                    BlockInfoSummary(
                        id=block_id,
                        name=block.name,
                        description=block.description or "",
                        categories=categories,
                        input_schema=input_schema,
                        output_schema=output_schema,
                        required_inputs=required_inputs,
                    )
                )

+                if len(blocks) >= _TARGET_RESULTS:
+                    break
+
+            if blocks and len(blocks) < _TARGET_RESULTS:
+                logger.debug(
+                    "find_block returned %d/%d results for query '%s' "
+                    "(filtered %d excluded/disabled blocks)",
+                    len(blocks),
+                    _TARGET_RESULTS,
+                    query,
+                    len(results) - len(blocks),
+                )

            if not blocks:
                return NoResultsResponse(
|||||||
@@ -0,0 +1,139 @@
|
|||||||
|
"""Tests for block filtering in FindBlockTool."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.find_block import (
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
|
FindBlockTool,
|
||||||
|
)
|
||||||
|
from backend.api.features.chat.tools.models import BlockListResponse
|
||||||
|
from backend.data.block import BlockType
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
_TEST_USER_ID = "test-user-find-block"
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_block(
|
||||||
|
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||||
|
):
|
||||||
|
"""Create a mock block for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.description = f"{name} description"
|
||||||
|
mock.block_type = block_type
|
||||||
|
mock.disabled = disabled
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||||
|
mock.input_schema.get_credentials_fields.return_value = {}
|
||||||
|
mock.output_schema = MagicMock()
|
||||||
|
mock.output_schema.jsonschema.return_value = {}
|
||||||
|
mock.categories = []
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindBlockFiltering:
|
||||||
|
"""Tests for block filtering in FindBlockTool."""
|
||||||
|
|
||||||
|
def test_excluded_block_types_contains_expected_types(self):
|
||||||
|
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
||||||
|
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
|
||||||
|
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
||||||
|
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||||
|
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_type_filtered_from_results(self):
|
||||||
|
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
||||||
|
search_results = [
|
||||||
|
{"content_id": "input-block-id", "score": 0.9},
|
||||||
|
{"content_id": "standard-block-id", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||||
|
standard_block = make_mock_block(
|
||||||
|
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_block(block_id):
|
||||||
|
return {
|
||||||
|
"input-block-id": input_block,
|
||||||
|
"standard-block-id": standard_block,
|
||||||
|
}.get(block_id)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, 2),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=mock_get_block,
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return the standard block, not the INPUT block
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert len(response.blocks) == 1
|
||||||
|
assert response.blocks[0].id == "standard-block-id"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_id_filtered_from_results(self):
|
||||||
|
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||||
|
search_results = [
|
||||||
|
{"content_id": smart_decision_id, "score": 0.9},
|
||||||
|
{"content_id": "normal-block-id", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
||||||
|
smart_block = make_mock_block(
|
||||||
|
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
normal_block = make_mock_block(
|
||||||
|
"normal-block-id", "Normal Block", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_block(block_id):
|
||||||
|
return {
|
||||||
|
smart_decision_id: smart_block,
|
||||||
|
"normal-block-id": normal_block,
|
||||||
|
}.get(block_id)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, 2),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=mock_get_block,
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="decision"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return normal block, not SmartDecisionMakerBlock
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert len(response.blocks) == 1
|
||||||
|
assert response.blocks[0].id == "normal-block-id"
|
||||||
@@ -0,0 +1,29 @@
"""Shared helpers for chat tools."""

from typing import Any


def get_inputs_from_schema(
    input_schema: dict[str, Any],
    exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
    """Extract input field info from JSON schema."""
    if not isinstance(input_schema, dict):
        return []

    exclude = exclude_fields or set()
    properties = input_schema.get("properties", {})
    required = set(input_schema.get("required", []))

    return [
        {
            "name": name,
            "title": schema.get("title", name),
            "type": schema.get("type", "string"),
            "description": schema.get("description", ""),
            "required": name in required,
            "default": schema.get("default"),
        }
        for name, schema in properties.items()
        if name not in exclude
    ]
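A quick usage sketch of the new shared helper, with a hand-written schema (the field names here are illustrative, not taken from any real agent):

# Illustrative call to get_inputs_from_schema with a made-up schema.
schema = {
    "properties": {
        "topic": {"type": "string", "title": "Topic", "description": "What to research"},
        "credentials": {"type": "object"},
    },
    "required": ["topic"],
}
fields = get_inputs_from_schema(schema, exclude_fields={"credentials"})
# -> [{"name": "topic", "title": "Topic", "type": "string",
#      "description": "What to research", "required": True, "default": None}]

The run_agent and run_block tools below both switch to this helper, replacing the per-tool _get_inputs_list implementation.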
@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
)

from .base import BaseTool
+from .helpers import get_inputs_from_schema
from .models import (
    AgentDetails,
    AgentDetailsResponse,
@@ -261,7 +262,7 @@ class RunAgentTool(BaseTool):
                ),
                requirements={
                    "credentials": requirements_creds_list,
-                    "inputs": self._get_inputs_list(graph.input_schema),
+                    "inputs": get_inputs_from_schema(graph.input_schema),
                    "execution_modes": self._get_execution_modes(graph),
                },
            ),
@@ -369,22 +370,6 @@ class RunAgentTool(BaseTool):
            session_id=session_id,
        )

-    def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
-        """Extract inputs list from schema."""
-        inputs_list = []
-        if isinstance(input_schema, dict) and "properties" in input_schema:
-            for field_name, field_schema in input_schema["properties"].items():
-                inputs_list.append(
-                    {
-                        "name": field_name,
-                        "title": field_schema.get("title", field_name),
-                        "type": field_schema.get("type", "string"),
-                        "description": field_schema.get("description", ""),
-                        "required": field_name in input_schema.get("required", []),
-                    }
-                )
-        return inputs_list
-
    def _get_execution_modes(self, graph: GraphModel) -> list[str]:
        """Get available execution modes for the graph."""
        trigger_info = graph.trigger_setup_info
@@ -398,7 +383,7 @@ class RunAgentTool(BaseTool):
        suffix: str,
    ) -> str:
        """Build a message describing available inputs for an agent."""
-        inputs_list = self._get_inputs_list(graph.input_schema)
+        inputs_list = get_inputs_from_schema(graph.input_schema)
        required_names = [i["name"] for i in inputs_list if i["required"]]
        optional_names = [i["name"] for i in inputs_list if not i["required"]]

|||||||
@@ -8,14 +8,19 @@ from typing import Any
|
|||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.block import get_block
|
from backend.api.features.chat.tools.find_block import (
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
|
)
|
||||||
|
from backend.data.block import AnyBlockSchema, get_block
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .helpers import get_inputs_from_schema
|
||||||
from .models import (
|
from .models import (
|
||||||
BlockOutputResponse,
|
BlockOutputResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -24,7 +29,10 @@ from .models import (
    ToolResponseBase,
    UserReadiness,
)
-from .utils import build_missing_credentials_from_field_info
+from .utils import (
+    build_missing_credentials_from_field_info,
+    match_credentials_to_requirements,
+)

logger = logging.getLogger(__name__)

@@ -73,91 +81,6 @@ class RunBlockTool(BaseTool):
    def requires_auth(self) -> bool:
        return True

-    async def _check_block_credentials(
-        self,
-        user_id: str,
-        block: Any,
-        input_data: dict[str, Any] | None = None,
-    ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
-        """
-        Check if user has required credentials for a block.
-
-        Args:
-            user_id: User ID
-            block: Block to check credentials for
-            input_data: Input data for the block (used to determine provider via discriminator)
-
-        Returns:
-            tuple[matched_credentials, missing_credentials]
-        """
-        matched_credentials: dict[str, CredentialsMetaInput] = {}
-        missing_credentials: list[CredentialsMetaInput] = []
-        input_data = input_data or {}
-
-        # Get credential field info from block's input schema
-        credentials_fields_info = block.input_schema.get_credentials_fields_info()
-
-        if not credentials_fields_info:
-            return matched_credentials, missing_credentials
-
-        # Get user's available credentials
-        creds_manager = IntegrationCredentialsManager()
-        available_creds = await creds_manager.store.get_all_creds(user_id)
-
-        for field_name, field_info in credentials_fields_info.items():
-            effective_field_info = field_info
-            if field_info.discriminator and field_info.discriminator_mapping:
-                # Get discriminator from input, falling back to schema default
-                discriminator_value = input_data.get(field_info.discriminator)
-                if discriminator_value is None:
-                    field = block.input_schema.model_fields.get(
-                        field_info.discriminator
-                    )
-                    if field and field.default is not PydanticUndefined:
-                        discriminator_value = field.default
-
-                if (
-                    discriminator_value
-                    and discriminator_value in field_info.discriminator_mapping
-                ):
-                    effective_field_info = field_info.discriminate(discriminator_value)
-                    logger.debug(
-                        f"Discriminated provider for {field_name}: "
-                        f"{discriminator_value} -> {effective_field_info.provider}"
-                    )
-
-            matching_cred = next(
-                (
-                    cred
-                    for cred in available_creds
-                    if cred.provider in effective_field_info.provider
-                    and cred.type in effective_field_info.supported_types
-                ),
-                None,
-            )
-
-            if matching_cred:
-                matched_credentials[field_name] = CredentialsMetaInput(
-                    id=matching_cred.id,
-                    provider=matching_cred.provider,  # type: ignore
-                    type=matching_cred.type,
-                    title=matching_cred.title,
-                )
-            else:
-                # Create a placeholder for the missing credential
-                provider = next(iter(effective_field_info.provider), "unknown")
-                cred_type = next(iter(effective_field_info.supported_types), "api_key")
-                missing_credentials.append(
-                    CredentialsMetaInput(
-                        id=field_name,
-                        provider=provider,  # type: ignore
-                        type=cred_type,  # type: ignore
-                        title=field_name.replace("_", " ").title(),
-                    )
-                )
-
-        return matched_credentials, missing_credentials
-
    async def _execute(
        self,
        user_id: str | None,
@@ -212,11 +135,24 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if block is excluded from CoPilot (graph-only blocks)
|
||||||
|
if (
|
||||||
|
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
):
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
||||||
|
"This block is designed for use within graphs only."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||||
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
creds_manager = IntegrationCredentialsManager()
|
||||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
matched_credentials, missing_credentials = (
|
||||||
user_id, block, input_data
|
await self._resolve_block_credentials(user_id, block, input_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
if missing_credentials:
|
if missing_credentials:
|
||||||
@@ -345,29 +281,75 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)

def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
async def _resolve_block_credentials(
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.

Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)

Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)

if not requirements:
return {}, []

return await match_credentials_to_requirements(user_id, requirements)

def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))

# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)

for field_name, field_schema in properties.items():
def _resolve_discriminated_credentials(
# Skip credential fields
self,
if field_name in credentials_fields:
block: AnyBlockSchema,
continue
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}

inputs_list.append(
resolved: dict[str, CredentialsFieldInfo] = {}
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)

return inputs_list
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info

if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default

if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)

resolved[field_name] = effective_field_info

return resolved
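The discriminator step above selects which provider's credentials a block actually needs based on one of its own input fields. A minimal, standalone illustration of that idea, using made-up mapping values rather than the project's real `CredentialsFieldInfo` data:

```python
# Illustrative only: a discriminator field (here "model") picks the provider.
# The mapping values are invented for this sketch.
discriminator_mapping = {"gpt-4o": "openai", "claude-sonnet-4": "anthropic"}


def resolve_provider(input_data: dict, discriminator: str, default: str | None) -> str | None:
    # Read the discriminator from the input, falling back to the schema default.
    value = input_data.get(discriminator, default)
    return discriminator_mapping.get(value)


assert resolve_provider({"model": "gpt-4o"}, "model", None) == "openai"
assert resolve_provider({}, "model", "claude-sonnet-4") == "anthropic"
```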
@@ -0,0 +1,106 @@
"""Tests for block execution guards in RunBlockTool."""

from unittest.mock import MagicMock, patch

import pytest

from backend.api.features.chat.tools.models import ErrorResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.data.block import BlockType

from ._test_data import make_session

_TEST_USER_ID = "test-user-run-block"


def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields_info.return_value = []
return mock


class TestRunBlockFiltering:
"""Tests for block execution guards in RunBlockTool."""

@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_returns_error(self):
"""Attempting to execute a block with excluded BlockType returns error."""
session = make_session(user_id=_TEST_USER_ID)

input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)

with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=input_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="input-block-id",
input_data={},
)

assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
assert "designed for use within graphs only" in response.message

@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_returns_error(self):
"""Attempting to execute SmartDecisionMakerBlock returns error."""
session = make_session(user_id=_TEST_USER_ID)

smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)

with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=smart_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=smart_decision_id,
input_data={},
)

assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message

@pytest.mark.asyncio(loop_scope="session")
async def test_non_excluded_block_passes_guard(self):
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
session = make_session(user_id=_TEST_USER_ID)

standard_block = make_mock_block(
"standard-id", "HTTP Request", BlockType.STANDARD
)

with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=standard_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="standard-id",
input_data={},
)

# Should NOT be an ErrorResponse about CoPilot exclusion
# (may be other errors like missing credentials, but not the exclusion guard)
if isinstance(response, ErrorResponse):
assert "cannot be run directly in CoPilot" not in response.message
@@ -8,6 +8,7 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data.graph import GraphModel
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
@@ -223,6 +224,99 @@ async def get_or_create_library_agent(
return library_agents[0]


async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.

This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []

if not requirements:
return matched, missing

available_creds = await get_user_credentials(user_id)

for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)

if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider,  # type: ignore
type=cred_type,  # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider,  # type: ignore
type=cred_type,  # type: ignore
title=field_name.replace("_", " ").title(),
)
)

return matched, missing


async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)


def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, scopes, and host."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if cred.type == "oauth2" and not _credential_has_required_scopes(
cred, field_info
):
continue
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
continue
return cred
return None


def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider,  # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)


async def match_user_credentials_to_graph(
user_id: str,
graph: GraphModel,
@@ -331,8 +425,6 @@ def _credential_has_required_scopes(
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True

# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)

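The matching rule introduced above (provider and credential type must match, and OAuth2 credentials additionally need a scope superset) can be illustrated with a small self-contained sketch. The dataclasses below are simplified stand-ins for the project's `Credentials` and `CredentialsFieldInfo` types, not the real models:

```python
# Illustrative sketch of the matching rule in find_matching_credential.
from dataclasses import dataclass, field


@dataclass
class FakeCredential:
    provider: str
    type: str
    scopes: set[str] = field(default_factory=set)


@dataclass
class FakeRequirement:
    provider: set[str]
    supported_types: set[str]
    required_scopes: set[str] = field(default_factory=set)


def pick_credential(creds: list[FakeCredential], req: FakeRequirement) -> FakeCredential | None:
    for cred in creds:
        # Provider and type must both be acceptable for the field.
        if cred.provider not in req.provider or cred.type not in req.supported_types:
            continue
        # OAuth2 credentials must cover every required scope.
        if cred.type == "oauth2" and not cred.scopes.issuperset(req.required_scopes):
            continue
        return cred
    return None


creds = [FakeCredential("github", "oauth2", {"repo"}), FakeCredential("openai", "api_key")]
req = FakeRequirement(provider={"github"}, supported_types={"oauth2"}, required_scopes={"repo"})
assert pick_credential(creds, req).provider == "github"
```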
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

try:
webset = aexa.websets.get(id=input_data.external_id)
webset = await aexa.websets.get(id=input_data.external_id)
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))

yield "webset", webset_result
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
count=input_data.search_count,
)

webset = aexa.websets.create(
webset = await aexa.websets.create(
params=CreateWebsetParameters(
search=search_params,
external_id=input_data.external_id,
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
if input_data.metadata is not None:
payload["metadata"] = input_data.metadata

sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)

status_str = (
sdk_webset.status.value
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

response = aexa.websets.list(
response = await aexa.websets.list(
cursor=input_data.cursor,
limit=input_data.limit,
)
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

sdk_webset = aexa.websets.get(id=input_data.webset_id)
sdk_webset = await aexa.websets.get(id=input_data.webset_id)

status_str = (
sdk_webset.status.value
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

deleted_webset = aexa.websets.delete(id=input_data.webset_id)
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)

status_str = (
deleted_webset.status.value
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)

status_str = (
canceled_webset.status.value
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
entity["description"] = input_data.entity_description
payload["entity"] = entity

sdk_preview = aexa.websets.preview(params=payload)
sdk_preview = await aexa.websets.preview(params=payload)

preview = PreviewWebsetModel.from_sdk(sdk_preview)

@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

webset = aexa.websets.get(id=input_data.webset_id)
webset = await aexa.websets.get(id=input_data.webset_id)

status = (
webset.status.value
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

webset = aexa.websets.get(id=input_data.webset_id)
webset = await aexa.websets.get(id=input_data.webset_id)

# Extract basic info
webset_id = webset.id
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
total_items = 0

if input_data.include_sample_items and input_data.sample_size > 0:
items_response = aexa.websets.items.list(
items_response = await aexa.websets.items.list(
webset_id=input_data.webset_id, limit=input_data.sample_size
)
sample_items_data = [
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

# Get webset details
webset = aexa.websets.get(id=input_data.webset_id)
webset = await aexa.websets.get(id=input_data.webset_id)

status = (
webset.status.value

@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

sdk_enrichment = aexa.websets.enrichments.create(
sdk_enrichment = await aexa.websets.enrichments.create(
webset_id=input_data.webset_id, params=payload
)

@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
items_enriched = 0

while time.time() - poll_start < input_data.polling_timeout:
current_enrich = aexa.websets.enrichments.get(
current_enrich = await aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=enrichment_id
)
current_status = (
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):

if current_status in ["completed", "failed", "cancelled"]:
# Estimate items from webset searches
webset = aexa.websets.get(id=input_data.webset_id)
webset = await aexa.websets.get(id=input_data.webset_id)
if webset.searches:
for search in webset.searches:
if search.progress:
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

sdk_enrichment = aexa.websets.enrichments.get(
sdk_enrichment = await aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)

@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

deleted_enrichment = aexa.websets.enrichments.delete(
deleted_enrichment = await aexa.websets.enrichments.delete(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)

@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

canceled_enrichment = aexa.websets.enrichments.cancel(
canceled_enrichment = await aexa.websets.enrichments.cancel(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)

# Try to estimate how many items were enriched before cancellation
items_enriched = 0
items_response = aexa.websets.items.list(
items_response = await aexa.websets.items.list(
webset_id=input_data.webset_id, limit=100
)

@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
def _create_test_mock():
"""Create test mocks for the AsyncExa SDK."""
from datetime import datetime
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

# Create mock SDK import object
mock_import = MagicMock()
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
return {
"_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock(
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
imports=MagicMock(create=AsyncMock(return_value=mock_import))
)
)
}
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
if input_data.metadata:
payload["metadata"] = input_data.metadata

sdk_import = aexa.websets.imports.create(
sdk_import = await aexa.websets.imports.create(
params=payload, csv_data=input_data.csv_data
)

@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)

import_obj = ImportModel.from_sdk(sdk_import)

@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

response = aexa.websets.imports.list(
response = await aexa.websets.imports.list(
cursor=input_data.cursor,
limit=input_data.limit,
)
@@ -474,7 +474,9 @@ class ExaDeleteImportBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
deleted_import = await aexa.websets.imports.delete(
import_id=input_data.import_id
)

yield "import_id", deleted_import.id
yield "success", "true"
@@ -573,14 +575,14 @@ class ExaExportWebsetBlock(Block):
}
)

# Create mock iterator
# Create async iterator for list_all
mock_items = [mock_item1, mock_item2]
async def async_item_iterator(*args, **kwargs):
for item in [mock_item1, mock_item2]:
yield item

return {
"_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock(
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
)
)
}

@@ -602,7 +604,7 @@ class ExaExportWebsetBlock(Block):
webset_id=input_data.webset_id, limit=input_data.max_items
)

for sdk_item in item_iterator:
async for sdk_item in item_iterator:
if len(all_items) >= input_data.max_items:
break

@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

sdk_item = aexa.websets.items.get(
sdk_item = await aexa.websets.items.get(
webset_id=input_data.webset_id, id=input_data.item_id
)

@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
response = None

while time.time() - start_time < input_data.wait_timeout:
response = aexa.websets.items.list(
response = await aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.cursor,
limit=input_data.limit,
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
interval = min(interval * 1.2, 10)

if not response:
response = aexa.websets.items.list(
response = await aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.cursor,
limit=input_data.limit,
)
else:
response = aexa.websets.items.list(
response = await aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.cursor,
limit=input_data.limit,
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

deleted_item = aexa.websets.items.delete(
deleted_item = await aexa.websets.items.delete(
webset_id=input_data.webset_id, id=input_data.item_id
)

@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
webset_id=input_data.webset_id, limit=input_data.max_items
)

for sdk_item in item_iterator:
async for sdk_item in item_iterator:
if len(all_items) >= input_data.max_items:
break

@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

webset = aexa.websets.get(id=input_data.webset_id)
webset = await aexa.websets.get(id=input_data.webset_id)

entity_type = "unknown"
if webset.searches:
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
# Get sample items if requested
sample_items: List[WebsetItemModel] = []
if input_data.sample_size > 0:
items_response = aexa.websets.items.list(
items_response = await aexa.websets.items.list(
webset_id=input_data.webset_id, limit=input_data.sample_size
)
# Convert to our stable models
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

# Get items starting from cursor
response = aexa.websets.items.list(
response = await aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.since_cursor,
limit=input_data.max_items,

@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
def _create_test_mock():
"""Create test mocks for the AsyncExa SDK."""
from datetime import datetime
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

# Create mock SDK monitor object
mock_monitor = MagicMock()
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
return {
"_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock(
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
)
)
}
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
if input_data.metadata:
payload["metadata"] = input_data.metadata

sdk_monitor = aexa.websets.monitors.create(params=payload)
sdk_monitor = await aexa.websets.monitors.create(params=payload)

monitor = MonitorModel.from_sdk(sdk_monitor)

@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)

monitor = MonitorModel.from_sdk(sdk_monitor)

@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
if input_data.metadata is not None:
payload["metadata"] = input_data.metadata

sdk_monitor = aexa.websets.monitors.update(
sdk_monitor = await aexa.websets.monitors.update(
monitor_id=input_data.monitor_id, params=payload
)

@@ -522,7 +522,9 @@ class ExaDeleteMonitorBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
deleted_monitor = await aexa.websets.monitors.delete(
monitor_id=input_data.monitor_id
)

yield "monitor_id", deleted_monitor.id
yield "success", "true"
@@ -579,7 +581,7 @@ class ExaListMonitorsBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

response = aexa.websets.monitors.list(
response = await aexa.websets.monitors.list(
cursor=input_data.cursor,
limit=input_data.limit,
webset_id=input_data.webset_id,

@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
WebsetTargetStatus.IDLE,
WebsetTargetStatus.ANY_COMPLETE,
]:
final_webset = aexa.websets.wait_until_idle(
final_webset = await aexa.websets.wait_until_idle(
id=input_data.webset_id,
timeout=input_data.timeout,
poll_interval=input_data.check_interval,
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
interval = input_data.check_interval
while time.time() - start_time < input_data.timeout:
# Get current webset status
webset = aexa.websets.get(id=input_data.webset_id)
webset = await aexa.websets.get(id=input_data.webset_id)
current_status = (
webset.status.value
if hasattr(webset.status, "value")
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):

# Timeout reached
elapsed = time.time() - start_time
webset = aexa.websets.get(id=input_data.webset_id)
webset = await aexa.websets.get(id=input_data.webset_id)
final_status = (
webset.status.value
if hasattr(webset.status, "value")
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
try:
while time.time() - start_time < input_data.timeout:
# Get current search status using SDK
search = aexa.websets.searches.get(
search = await aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)

@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
elapsed = time.time() - start_time

# Get last known status
search = aexa.websets.searches.get(
search = await aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)
final_status = (
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
try:
while time.time() - start_time < input_data.timeout:
# Get current enrichment status using SDK
enrichment = aexa.websets.enrichments.get(
enrichment = await aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)

@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
elapsed = time.time() - start_time

# Get last known status
enrichment = aexa.websets.enrichments.get(
enrichment = await aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
final_status = (
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
) -> tuple[list[SampleEnrichmentModel], int]:
"""Get sample enriched data and count."""
# Get a few items to see enrichment results using SDK
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)

sample_data: list[SampleEnrichmentModel] = []
enriched_count = 0

@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):

aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

sdk_search = aexa.websets.searches.create(
sdk_search = await aexa.websets.searches.create(
webset_id=input_data.webset_id, params=payload
)

@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
poll_start = time.time()

while time.time() - poll_start < input_data.polling_timeout:
current_search = aexa.websets.searches.get(
current_search = await aexa.websets.searches.get(
webset_id=input_data.webset_id, id=search_id
)
current_status = (
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

sdk_search = aexa.websets.searches.get(
sdk_search = await aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)

@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

canceled_search = aexa.websets.searches.cancel(
canceled_search = await aexa.websets.searches.cancel(
webset_id=input_data.webset_id, id=input_data.search_id
)

@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

# Get webset to check existing searches
webset = aexa.websets.get(id=input_data.webset_id)
webset = await aexa.websets.get(id=input_data.webset_id)

# Look for existing search with same query
existing_search = None
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
if input_data.entity_type != SearchEntityType.AUTO:
payload["entity"] = {"type": input_data.entity_type.value}

sdk_search = aexa.websets.searches.create(
sdk_search = await aexa.websets.searches.create(
webset_id=input_data.webset_id, params=payload
)

@@ -531,12 +531,12 @@ class LLMResponse(BaseModel):

def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | anthropic.NotGiven:
) -> Iterable[ToolParam] | anthropic.Omit:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
if not openai_tools or len(openai_tools) == 0:
return anthropic.NOT_GIVEN
return anthropic.omit

anthropic_tools = []
for tool in openai_tools:
@@ -596,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:

def get_parallel_tool_calls_param(
llm_model: LlmModel, parallel_tool_calls: bool | None
):
) -> bool | openai.Omit:
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
if llm_model.startswith("o") or parallel_tool_calls is None:
return openai.NOT_GIVEN
return openai.omit
return parallel_tool_calls

@@ -1,9 +1,8 @@
import logging
import queue
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import (
TYPE_CHECKING,
Annotated,
@@ -1200,12 +1199,16 @@ class NodeExecutionEntry(BaseModel):

class ExecutionQueue(Generic[T]):
"""
Queue for managing the execution of agents.
Thread-safe queue for managing node execution within a single graph execution.
This will be shared between different processes

Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
threads within the same process. If migrating back to ProcessPoolExecutor,
replace with multiprocessing.Manager().Queue() for cross-process safety.
"""

def __init__(self):
self.queue = Manager().Queue()
# Thread-safe queue (not multiprocessing) — see class docstring
self.queue: queue.Queue[T] = queue.Queue()

def add(self, execution: T) -> T:
self.queue.put(execution)
@@ -1220,7 +1223,7 @@ class ExecutionQueue(Generic[T]):
def get_or_none(self) -> T | None:
try:
return self.queue.get_nowait()
except Empty:
except queue.Empty:
return None

@@ -0,0 +1,58 @@
"""Tests for ExecutionQueue thread-safety."""

import queue
import threading

from backend.data.execution import ExecutionQueue


def test_execution_queue_uses_stdlib_queue():
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
q = ExecutionQueue()
assert isinstance(q.queue, queue.Queue)


def test_basic_operations():
"""Test add, get, empty, and get_or_none."""
q = ExecutionQueue()

assert q.empty() is True
assert q.get_or_none() is None

result = q.add("item1")
assert result == "item1"
assert q.empty() is False

item = q.get()
assert item == "item1"
assert q.empty() is True


def test_thread_safety():
"""Test concurrent access from multiple threads."""
q = ExecutionQueue()
results = []
num_items = 100

def producer():
for i in range(num_items):
q.add(f"item_{i}")

def consumer():
count = 0
while count < num_items:
item = q.get_or_none()
if item is not None:
results.append(item)
count += 1

producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)

producer_thread.start()
consumer_thread.start()

producer_thread.join(timeout=5)
consumer_thread.join(timeout=5)

assert len(results) == num_items
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -225,6 +226,10 @@ class SyncRabbitMQ(RabbitMQBase):
 class AsyncRabbitMQ(RabbitMQBase):
     """Asynchronous RabbitMQ client"""
 
+    def __init__(self, config: RabbitMQConfig):
+        super().__init__(config)
+        self._reconnect_lock: asyncio.Lock | None = None
+
     @property
     def is_connected(self) -> bool:
         return bool(self._connection and not self._connection.is_closed)
@@ -235,7 +240,17 @@ class AsyncRabbitMQ(RabbitMQBase):
 
     @conn_retry("AsyncRabbitMQ", "Acquiring async connection")
     async def connect(self):
-        if self.is_connected:
+        if self.is_connected and self._channel and not self._channel.is_closed:
+            return
+
+        if (
+            self.is_connected
+            and self._connection
+            and (self._channel is None or self._channel.is_closed)
+        ):
+            self._channel = await self._connection.channel()
+            await self._channel.set_qos(prefetch_count=1)
+            await self.declare_infrastructure()
             return
 
         self._connection = await aio_pika.connect_robust(
@@ -291,24 +306,46 @@ class AsyncRabbitMQ(RabbitMQBase):
                 exchange, routing_key=queue.routing_key or queue.name
             )
 
-    @func_retry
-    async def publish_message(
+    @property
+    def _lock(self) -> asyncio.Lock:
+        if self._reconnect_lock is None:
+            self._reconnect_lock = asyncio.Lock()
+        return self._reconnect_lock
+
+    async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
+        """Get a valid channel, reconnecting if the current one is stale.
+
+        Uses a lock to prevent concurrent reconnection attempts from racing.
+        """
+        if self.is_ready:
+            return self._channel  # type: ignore # is_ready guarantees non-None
+
+        async with self._lock:
+            # Double-check after acquiring lock
+            if self.is_ready:
+                return self._channel  # type: ignore
+
+            self._channel = None
+            await self.connect()
+
+            if self._channel is None:
+                raise RuntimeError("Channel should be established after connect")
+
+            return self._channel
+
+    async def _publish_once(
         self,
         routing_key: str,
         message: str,
         exchange: Optional[Exchange] = None,
         persistent: bool = True,
     ) -> None:
-        if not self.is_ready:
-            await self.connect()
-
-        if self._channel is None:
-            raise RuntimeError("Channel should be established after connect")
+        channel = await self._ensure_channel()
 
         if exchange:
-            exchange_obj = await self._channel.get_exchange(exchange.name)
+            exchange_obj = await channel.get_exchange(exchange.name)
         else:
-            exchange_obj = self._channel.default_exchange
+            exchange_obj = channel.default_exchange
 
         await exchange_obj.publish(
             aio_pika.Message(
@@ -322,9 +359,23 @@ class AsyncRabbitMQ(RabbitMQBase):
             routing_key=routing_key,
         )
 
+    @func_retry
+    async def publish_message(
+        self,
+        routing_key: str,
+        message: str,
+        exchange: Optional[Exchange] = None,
+        persistent: bool = True,
+    ) -> None:
+        try:
+            await self._publish_once(routing_key, message, exchange, persistent)
+        except aio_pika.exceptions.ChannelInvalidStateError:
+            logger.warning(
+                "RabbitMQ channel invalid, forcing reconnect and retrying publish"
+            )
+            async with self._lock:
+                self._channel = None
+            await self._publish_once(routing_key, message, exchange, persistent)
+
     async def get_channel(self) -> aio_pika.abc.AbstractChannel:
-        if not self.is_ready:
-            await self.connect()
-        if self._channel is None:
-            raise RuntimeError("Channel should be established after connect")
-        return self._channel
+        return await self._ensure_channel()
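Note on the hunks above: _ensure_channel follows a check / lock / re-check shape so that many coroutines publishing at once trigger at most one reconnect. Below is a minimal, self-contained sketch of that pattern only; it is illustrative, not code from this PR, and the LazyResource class and its names are invented for the example.

import asyncio


class LazyResource:
    """Toy stand-in for AsyncRabbitMQ's channel management."""

    def __init__(self) -> None:
        self._value: object | None = None
        self._lock: asyncio.Lock | None = None

    @property
    def _guard(self) -> asyncio.Lock:
        # Created lazily, mirroring the new _reconnect_lock property.
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    async def ensure(self) -> object:
        if self._value is not None:  # fast path: no lock when already healthy
            return self._value
        async with self._guard:
            if self._value is not None:  # double-check after acquiring the lock
                return self._value
            await asyncio.sleep(0)  # stand-in for the actual reconnect work
            self._value = object()
            return self._value


async def main() -> None:
    res = LazyResource()
    results = await asyncio.gather(*(res.ensure() for _ in range(5)))
    assert all(r is results[0] for r in results)  # the resource was built exactly once


asyncio.run(main())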
@@ -342,6 +342,14 @@ async def store_media_file(
         if not target_path.is_file():
             raise ValueError(f"Local file does not exist: {target_path}")
 
+        # Virus scan the local file before any further processing
+        local_content = target_path.read_bytes()
+        if len(local_content) > MAX_FILE_SIZE_BYTES:
+            raise ValueError(
+                f"File too large: {len(local_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
+            )
+        await scan_content_safe(local_content, filename=sanitized_file)
+
         # Return based on requested format
         if return_format == "for_local_processing":
             # Use when processing files locally with tools like ffmpeg, MoviePy, PIL
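Note on the store_media_file hunk above: local file contents are now size-checked and then virus-scanned before any further processing. A rough, self-contained sketch of that gate follows; the limit, the exception class, and the scanner body are placeholders, not the platform's actual MAX_FILE_SIZE_BYTES, VirusDetectedError, or scan_content_safe implementation.

import asyncio

MAX_FILE_SIZE_BYTES = 256 * 1024 * 1024  # illustrative cap, not the real setting


class VirusDetectedError(Exception):
    """Raised when the scanner flags the content (placeholder class)."""


async def scan_content_safe(content: bytes, filename: str) -> None:
    # Placeholder scanner; the platform delegates to its ClamAV-backed service.
    if b"EICAR" in content:
        raise VirusDetectedError(f"File rejected due to virus detection: {filename}")


async def vet_local_file(content: bytes, filename: str) -> None:
    # Size check first so the scanner never receives unbounded input.
    if len(content) > MAX_FILE_SIZE_BYTES:
        raise ValueError(
            f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
        )
    await scan_content_safe(content, filename=filename)


asyncio.run(vet_local_file(b"fake video content", "test_video.mp4"))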
@@ -247,3 +247,100 @@ class TestFileCloudIntegration:
                 execution_context=make_test_context(graph_exec_id=graph_exec_id),
                 return_format="for_local_processing",
             )
+
+    @pytest.mark.asyncio
+    async def test_store_media_file_local_path_scanned(self):
+        """Test that local file paths are scanned for viruses."""
+        graph_exec_id = "test-exec-123"
+        local_file = "test_video.mp4"
+        file_content = b"fake video content"
+
+        with patch(
+            "backend.util.file.get_cloud_storage_handler"
+        ) as mock_handler_getter, patch(
+            "backend.util.file.scan_content_safe"
+        ) as mock_scan, patch(
+            "backend.util.file.Path"
+        ) as mock_path_class:
+
+            # Mock cloud storage handler - not a cloud path
+            mock_handler = MagicMock()
+            mock_handler.is_cloud_path.return_value = False
+            mock_handler_getter.return_value = mock_handler
+
+            # Mock virus scanner
+            mock_scan.return_value = None
+
+            # Mock file system operations
+            mock_base_path = MagicMock()
+            mock_target_path = MagicMock()
+            mock_resolved_path = MagicMock()
+
+            mock_path_class.return_value = mock_base_path
+            mock_base_path.mkdir = MagicMock()
+            mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
+            mock_target_path.resolve.return_value = mock_resolved_path
+            mock_resolved_path.is_relative_to.return_value = True
+            mock_resolved_path.is_file.return_value = True
+            mock_resolved_path.read_bytes.return_value = file_content
+            mock_resolved_path.relative_to.return_value = Path(local_file)
+            mock_resolved_path.name = local_file
+
+            result = await store_media_file(
+                file=MediaFileType(local_file),
+                execution_context=make_test_context(graph_exec_id=graph_exec_id),
+                return_format="for_local_processing",
+            )
+
+            # Verify virus scan was called for local file
+            mock_scan.assert_called_once_with(file_content, filename=local_file)
+
+            # Result should be the relative path
+            assert str(result) == local_file
+
+    @pytest.mark.asyncio
+    async def test_store_media_file_local_path_virus_detected(self):
+        """Test that infected local files raise VirusDetectedError."""
+        from backend.api.features.store.exceptions import VirusDetectedError
+
+        graph_exec_id = "test-exec-123"
+        local_file = "infected.exe"
+        file_content = b"malicious content"
+
+        with patch(
+            "backend.util.file.get_cloud_storage_handler"
+        ) as mock_handler_getter, patch(
+            "backend.util.file.scan_content_safe"
+        ) as mock_scan, patch(
+            "backend.util.file.Path"
+        ) as mock_path_class:
+
+            # Mock cloud storage handler - not a cloud path
+            mock_handler = MagicMock()
+            mock_handler.is_cloud_path.return_value = False
+            mock_handler_getter.return_value = mock_handler
+
+            # Mock virus scanner to detect virus
+            mock_scan.side_effect = VirusDetectedError(
+                "EICAR-Test-File", "File rejected due to virus detection"
+            )
+
+            # Mock file system operations
+            mock_base_path = MagicMock()
+            mock_target_path = MagicMock()
+            mock_resolved_path = MagicMock()
+
+            mock_path_class.return_value = mock_base_path
+            mock_base_path.mkdir = MagicMock()
+            mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
+            mock_target_path.resolve.return_value = mock_resolved_path
+            mock_resolved_path.is_relative_to.return_value = True
+            mock_resolved_path.is_file.return_value = True
+            mock_resolved_path.read_bytes.return_value = file_content
+
+            with pytest.raises(VirusDetectedError):
+                await store_media_file(
+                    file=MediaFileType(local_file),
+                    execution_context=make_test_context(graph_exec_id=graph_exec_id),
+                    return_format="for_local_processing",
+                )
autogpt_platform/backend/poetry.lock (generated, 7199 changed lines): diff suppressed because it is too large.
@@ -12,16 +12,17 @@ python = ">=3.10,<3.14"
 aio-pika = "^9.5.5"
 aiohttp = "^3.10.0"
 aiodns = "^3.5.0"
-anthropic = "^0.59.0"
+anthropic = "^0.79.0"
 apscheduler = "^3.11.1"
 autogpt-libs = { path = "../autogpt_libs", develop = true }
 bleach = { extras = ["css"], version = "^6.2.0" }
+claude-agent-sdk = "^0.1.0"
 click = "^8.2.0"
-cryptography = "^45.0"
+cryptography = "^46.0"
 discord-py = "^2.5.2"
 e2b-code-interpreter = "^1.5.2"
 elevenlabs = "^1.50.0"
-fastapi = "^0.116.1"
+fastapi = "^0.128.5"
 feedparser = "^6.0.11"
 flake8 = "^7.3.0"
 google-api-python-client = "^2.177.0"
@@ -35,10 +36,10 @@ jinja2 = "^3.1.6"
 jsonref = "^1.1.0"
 jsonschema = "^4.25.0"
 langfuse = "^3.11.0"
-launchdarkly-server-sdk = "^9.12.0"
+launchdarkly-server-sdk = "^9.14.1"
 mem0ai = "^0.1.115"
 moviepy = "^2.1.2"
-ollama = "^0.5.1"
+ollama = "^0.6.1"
 openai = "^1.97.1"
 orjson = "^3.10.0"
 pika = "^1.3.2"
@@ -48,16 +49,16 @@ postmarker = "^1.0"
 praw = "~7.8.1"
 prisma = "^0.15.0"
 rank-bm25 = "^0.2.2"
-prometheus-client = "^0.22.1"
+prometheus-client = "^0.24.1"
 prometheus-fastapi-instrumentator = "^7.0.0"
 psutil = "^7.0.0"
 psycopg2-binary = "^2.9.10"
-pydantic = { extras = ["email"], version = "^2.11.7" }
-pydantic-settings = "^2.10.1"
+pydantic = { extras = ["email"], version = "^2.12.5" }
+pydantic-settings = "^2.12.0"
 pytest = "^8.4.1"
 pytest-asyncio = "^1.1.0"
 python-dotenv = "^1.1.1"
-python-multipart = "^0.0.20"
+python-multipart = "^0.0.22"
 redis = "^6.2.0"
 regex = "^2025.9.18"
 replicate = "^1.0.6"
@@ -65,11 +66,11 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
 sqlalchemy = "^2.0.40"
 strenum = "^0.4.9"
 stripe = "^11.5.0"
-supabase = "2.17.0"
-tenacity = "^9.1.2"
+supabase = "2.27.3"
+tenacity = "^9.1.4"
 todoist-api-python = "^2.1.7"
 tweepy = "^4.16.0"
-uvicorn = { extras = ["standard"], version = "^0.35.0" }
+uvicorn = { extras = ["standard"], version = "^0.40.0" }
 websockets = "^15.0"
 youtube-transcript-api = "^1.2.1"
 yt-dlp = "2025.12.08"
@@ -77,7 +78,7 @@ zerobouncesdk = "^1.1.2"
 # NOTE: please insert new dependencies in their alphabetical location
 pytest-snapshot = "^0.9.0"
 aiofiles = "^24.1.0"
-tiktoken = "^0.9.0"
+tiktoken = "^0.12.0"
 aioclamd = "^1.0.0"
 setuptools = "^80.9.0"
 gcloud-aio-storage = "^9.5.0"
@@ -95,13 +96,13 @@ black = "^24.10.0"
 faker = "^38.2.0"
 httpx = "^0.28.1"
 isort = "^5.13.2"
-poethepoet = "^0.37.0"
+poethepoet = "^0.41.0"
 pre-commit = "^4.4.0"
 pyright = "^1.1.407"
 pytest-mock = "^3.15.1"
-pytest-watcher = "^0.4.2"
+pytest-watcher = "^0.6.3"
 requests = "^2.32.5"
-ruff = "^0.14.5"
+ruff = "^0.15.0"
 # NOTE: please insert new dependencies in their alphabetical location
 
 [build-system]
@@ -102,7 +102,7 @@
     "react-markdown": "9.0.3",
     "react-modal": "3.16.3",
     "react-shepherd": "6.1.9",
-    "react-window": "1.8.11",
+    "react-window": "2.2.0",
     "recharts": "3.3.0",
     "rehype-autolink-headings": "7.1.0",
    "rehype-highlight": "7.0.2",
@@ -140,7 +140,7 @@
     "@types/react": "18.3.17",
     "@types/react-dom": "18.3.5",
     "@types/react-modal": "3.16.3",
-    "@types/react-window": "1.8.8",
+    "@types/react-window": "2.0.0",
     "@vitejs/plugin-react": "5.1.2",
     "axe-playwright": "2.2.2",
     "chromatic": "13.3.3",
autogpt_platform/frontend/pnpm-lock.yaml (generated, 38 changed lines):
@@ -228,8 +228,8 @@ importers:
         specifier: 6.1.9
         version: 6.1.9(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(typescript@5.9.3)
       react-window:
-        specifier: 1.8.11
-        version: 1.8.11(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+        specifier: 2.2.0
+        version: 2.2.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
       recharts:
         specifier: 3.3.0
         version: 3.3.0(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react-is@18.3.1)(react@18.3.1)(redux@5.0.1)
@@ -337,8 +337,8 @@ importers:
         specifier: 3.16.3
         version: 3.16.3
       '@types/react-window':
-        specifier: 1.8.8
-        version: 1.8.8
+        specifier: 2.0.0
+        version: 2.0.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
       '@vitejs/plugin-react':
         specifier: 5.1.2
         version: 5.1.2(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2))
@@ -3469,8 +3469,9 @@ packages:
   '@types/react-modal@3.16.3':
     resolution: {integrity: sha512-xXuGavyEGaFQDgBv4UVm8/ZsG+qxeQ7f77yNrW3n+1J6XAstUy5rYHeIHPh1KzsGc6IkCIdu6lQ2xWzu1jBTLg==}
 
-  '@types/react-window@1.8.8':
-    resolution: {integrity: sha512-8Ls660bHR1AUA2kuRvVG9D/4XpRC6wjAaPT9dil7Ckc76eP9TKWZwwmgfq8Q1LANX3QNDnoU4Zp48A3w+zK69Q==}
+  '@types/react-window@2.0.0':
+    resolution: {integrity: sha512-E8hMDtImEpMk1SjswSvqoSmYvk7GEtyVaTa/GJV++FdDNuMVVEzpAClyJ0nqeKYBrMkGiyH6M1+rPLM0Nu1exQ==}
+    deprecated: This is a stub types definition. react-window provides its own type definitions, so you do not need this installed.
 
   '@types/react@18.3.17':
     resolution: {integrity: sha512-opAQ5no6LqJNo9TqnxBKsgnkIYHozW9KSTlFVoSUJYh1Fl/sswkEoqIugRSm7tbh6pABtYjGAjW+GOS23j8qbw==}
@@ -5976,9 +5977,6 @@ packages:
     resolution: {integrity: sha512-UERzLsxzllchadvbPs5aolHh65ISpKpM+ccLbOJ8/vvpBKmAWf+la7dXFy7Mr0ySHbdHrFv5kGFCUHHe6GFEmw==}
     engines: {node: '>= 4.0.0'}
 
-  memoize-one@5.2.1:
-    resolution: {integrity: sha512-zYiwtZUcYyXKo/np96AGZAckk+FWWsUdJ3cHGGmld7+AhvcWmQyGCYUh1hc4Q/pkOhb65dQR/pqCyK0cOaHz4Q==}
-
   merge-stream@2.0.0:
     resolution: {integrity: sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==}
 
@@ -6891,12 +6889,11 @@ packages:
       '@types/react':
         optional: true
 
-  react-window@1.8.11:
-    resolution: {integrity: sha512-+SRbUVT2scadgFSWx+R1P754xHPEqvcfSfVX10QYg6POOz+WNgkN48pS+BtZNIMGiL1HYrSEiCkwsMS15QogEQ==}
-    engines: {node: '>8.0.0'}
+  react-window@2.2.0:
+    resolution: {integrity: sha512-Y2L7yonHq6K1pQA2P98wT5QdIsEcjBTB7T8o6Mub12hH9eYppXoYu6vgClmcjlh3zfNcW2UrXiJJJqDxUY7GVw==}
     peerDependencies:
-      react: ^15.0.0 || ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
-      react-dom: ^15.0.0 || ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
+      react: ^18.0.0 || ^19.0.0
+      react-dom: ^18.0.0 || ^19.0.0
 
   react@18.3.1:
     resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==}
@@ -11603,9 +11600,12 @@ snapshots:
     dependencies:
      '@types/react': 18.3.17
 
-  '@types/react-window@1.8.8':
+  '@types/react-window@2.0.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
     dependencies:
-      '@types/react': 18.3.17
+      react-window: 2.2.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+    transitivePeerDependencies:
+      - react
+      - react-dom
 
   '@types/react@18.3.17':
     dependencies:
@@ -14545,8 +14545,6 @@ snapshots:
    dependencies:
      fs-monkey: 1.1.0
 
-  memoize-one@5.2.1: {}
-
   merge-stream@2.0.0: {}
 
   merge2@1.4.1: {}
@@ -15592,10 +15590,8 @@ snapshots:
     optionalDependencies:
       '@types/react': 18.3.17
 
-  react-window@1.8.11(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
+  react-window@2.2.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
     dependencies:
-      '@babel/runtime': 7.28.4
-      memoize-one: 5.2.1
       react: 18.3.1
       react-dom: 18.3.1(react@18.3.1)
 
@@ -12307,7 +12307,9 @@
           "title": "Location"
         },
         "msg": { "type": "string", "title": "Message" },
-        "type": { "type": "string", "title": "Error Type" }
+        "type": { "type": "string", "title": "Error Type" },
+        "input": { "title": "Input" },
+        "ctx": { "type": "object", "title": "Context" }
       },
       "type": "object",
       "required": ["loc", "msg", "type"],
@@ -104,7 +104,31 @@ export function FileInput(props: Props) {
     return false;
   }
 
-  const getFileLabelFromValue = (val: string) => {
+  const getFileLabelFromValue = (val: unknown): string => {
+    // Handle object format from external API: { name, type, size, data }
+    if (val && typeof val === "object") {
+      const obj = val as Record<string, unknown>;
+      if (typeof obj.name === "string") {
+        return getFileLabel(
+          obj.name,
+          typeof obj.type === "string" ? obj.type : "",
+        );
+      }
+      if (typeof obj.type === "string") {
+        const mimeParts = obj.type.split("/");
+        if (mimeParts.length > 1) {
+          return `${mimeParts[1].toUpperCase()} file`;
+        }
+        return `${obj.type} file`;
+      }
+      return "File";
+    }
+
+    // Handle string values (data URIs or file paths)
+    if (typeof val !== "string") {
+      return "File";
+    }
+
     if (val.startsWith("data:")) {
       const matches = val.match(/^data:([^;]+);/);
       if (matches?.[1]) {
@@ -4,7 +4,7 @@ import { Button } from "@/components/atoms/Button/Button";
 import { Input } from "@/components/atoms/Input/Input";
 import { Text } from "@/components/atoms/Text/Text";
 import { Bell, MagnifyingGlass, X } from "@phosphor-icons/react";
-import { FixedSizeList as List } from "react-window";
+import { List, type RowComponentProps } from "react-window";
 import { AgentExecutionWithInfo } from "../../helpers";
 import { ActivityItem } from "../ActivityItem";
 import styles from "./styles.module.css";
@@ -19,14 +19,16 @@ interface Props {
   recentFailures: AgentExecutionWithInfo[];
 }
 
-interface VirtualizedItemProps {
-  index: number;
-  style: React.CSSProperties;
-  data: AgentExecutionWithInfo[];
+interface ActivityRowProps {
+  executions: AgentExecutionWithInfo[];
 }
 
-function VirtualizedActivityItem({ index, style, data }: VirtualizedItemProps) {
-  const execution = data[index];
+function VirtualizedActivityItem({
+  index,
+  style,
+  executions,
+}: RowComponentProps<ActivityRowProps>) {
+  const execution = executions[index];
   return (
     <div style={style}>
       <ActivityItem execution={execution} />
@@ -129,14 +131,13 @@ export function ActivityDropdown({
         >
           {filteredExecutions.length > 0 ? (
             <List
-              height={listHeight}
-              width={320} // Match dropdown width (w-80 = 20rem = 320px)
-              itemCount={filteredExecutions.length}
-              itemSize={itemHeight}
-              itemData={filteredExecutions}
-            >
-              {VirtualizedActivityItem}
-            </List>
+              defaultHeight={listHeight}
+              rowCount={filteredExecutions.length}
+              rowHeight={itemHeight}
+              rowProps={{ executions: filteredExecutions }}
+              rowComponent={VirtualizedActivityItem}
+              style={{ width: 320, height: listHeight }}
+            />
           ) : (
             <div className="flex h-full flex-col items-center justify-center gap-5 pb-8 pt-6">
               <div className="mx-auto inline-flex flex-col items-center justify-center rounded-full bg-bgLightGrey p-6">