Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-11 23:35:25 -05:00)

Compare commits — 7 commits: chore/comb... ... refactor/a...
| Author | SHA1 | Date |
|---|---|---|
|  | ad1a814724 |  |
|  | 562cf04ab6 |  |
|  | 90b3b5ba16 |  |
|  | f4f81bc4fc |  |
|  | c5abc01f25 |  |
|  | 8b7053c1de |  |
|  | e00c1202ad |  |
.github/workflows/classic-frontend-ci.yml (vendored, 2 changes)

@@ -49,7 +49,7 @@ jobs:
 
       - name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
         if: github.event_name == 'push'
-        uses: peter-evans/create-pull-request@v8
+        uses: peter-evans/create-pull-request@v7
         with:
           add-paths: classic/frontend/build/web
           base: ${{ github.ref_name }}

@@ -42,7 +42,7 @@ jobs:
 
       - name: Get CI failure details
         id: failure_details
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
            const run = await github.rest.actions.getWorkflowRun({

.github/workflows/claude-dependabot.yml (vendored, 9 changes)

@@ -41,7 +41,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -78,7 +78,7 @@ jobs:
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22"
 

@@ -91,7 +91,7 @@ jobs:
           echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
 
       - name: Cache frontend dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}

@@ -124,7 +124,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes

@@ -309,7 +309,6 @@ jobs:
         uses: anthropics/claude-code-action@v1
         with:
           claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
-          allowed_bots: "dependabot[bot]"
           claude_args: |
             --allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
           prompt: |

.github/workflows/claude.yml (vendored, 8 changes)

@@ -57,7 +57,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -94,7 +94,7 @@ jobs:
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22"
 

@@ -107,7 +107,7 @@ jobs:
           echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
 
       - name: Cache frontend dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}

@@ -140,7 +140,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes

.github/workflows/copilot-setup-steps.yml (vendored, 8 changes)

@@ -39,7 +39,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -76,7 +76,7 @@ jobs:
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22"
 

@@ -89,7 +89,7 @@ jobs:
           echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
 
       - name: Cache frontend dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}

@@ -132,7 +132,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes

.github/workflows/docs-block-sync.yml (vendored, 2 changes)

@@ -33,7 +33,7 @@ jobs:
           python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

.github/workflows/docs-claude-review.yml (vendored, 2 changes)

@@ -33,7 +33,7 @@ jobs:
           python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

.github/workflows/docs-enhance.yml (vendored, 2 changes)

@@ -38,7 +38,7 @@ jobs:
           python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -52,7 +52,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Trigger deploy workflow
-        uses: peter-evans/repository-dispatch@v4
+        uses: peter-evans/repository-dispatch@v3
         with:
           token: ${{ secrets.DEPLOY_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

@@ -45,7 +45,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
      - name: Trigger deploy workflow
-        uses: peter-evans/repository-dispatch@v4
+        uses: peter-evans/repository-dispatch@v3
         with:
           token: ${{ secrets.DEPLOY_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

.github/workflows/platform-backend-ci.yml (vendored, 2 changes)

@@ -88,7 +88,7 @@ jobs:
         run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

@@ -17,7 +17,7 @@ jobs:
       - name: Check comment permissions and deployment status
         id: check_status
         if: github.event_name == 'issue_comment' && github.event.issue.pull_request
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
            const commentBody = context.payload.comment.body.trim();

@@ -55,7 +55,7 @@ jobs:
 
       - name: Post permission denied comment
         if: steps.check_status.outputs.permission_denied == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
            await github.rest.issues.createComment({

@@ -68,7 +68,7 @@ jobs:
       - name: Get PR details for deployment
         id: pr_details
         if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
            const pr = await github.rest.pulls.get({

@@ -82,7 +82,7 @@ jobs:
 
       - name: Dispatch Deploy Event
         if: steps.check_status.outputs.should_deploy == 'true'
-        uses: peter-evans/repository-dispatch@v4
+        uses: peter-evans/repository-dispatch@v3
         with:
           token: ${{ secrets.DISPATCH_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

@@ -98,7 +98,7 @@ jobs:
 
       - name: Post deploy success comment
         if: steps.check_status.outputs.should_deploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
            await github.rest.issues.createComment({

@@ -110,7 +110,7 @@ jobs:
 
       - name: Dispatch Undeploy Event (from comment)
         if: steps.check_status.outputs.should_undeploy == 'true'
-        uses: peter-evans/repository-dispatch@v4
+        uses: peter-evans/repository-dispatch@v3
         with:
           token: ${{ secrets.DISPATCH_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

@@ -126,7 +126,7 @@ jobs:
 
       - name: Post undeploy success comment
         if: steps.check_status.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
            await github.rest.issues.createComment({

@@ -139,7 +139,7 @@ jobs:
       - name: Check deployment status on PR close
         id: check_pr_close
         if: github.event_name == 'pull_request' && github.event.action == 'closed'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
            const comments = await github.rest.issues.listComments({

@@ -168,7 +168,7 @@ jobs:
           github.event_name == 'pull_request' &&
           github.event.action == 'closed' &&
           steps.check_pr_close.outputs.should_undeploy == 'true'
-        uses: peter-evans/repository-dispatch@v4
+        uses: peter-evans/repository-dispatch@v3
         with:
           token: ${{ secrets.DISPATCH_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

@@ -187,7 +187,7 @@ jobs:
           github.event_name == 'pull_request' &&
           github.event.action == 'closed' &&
           steps.check_pr_close.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
            await github.rest.issues.createComment({

.github/workflows/platform-frontend-ci.yml (vendored, 22 changes)

@@ -42,7 +42,7 @@ jobs:
             - 'autogpt_platform/frontend/src/components/**'
 
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
 

@@ -54,7 +54,7 @@ jobs:
         run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
 
       - name: Cache dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ steps.cache-key.outputs.key }}

@@ -74,7 +74,7 @@ jobs:
         uses: actions/checkout@v4
 
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
 

@@ -82,7 +82,7 @@ jobs:
         run: corepack enable
 
       - name: Restore dependencies cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}

@@ -112,7 +112,7 @@ jobs:
           fetch-depth: 0
 
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
 

@@ -120,7 +120,7 @@ jobs:
         run: corepack enable
 
       - name: Restore dependencies cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}

@@ -153,7 +153,7 @@ jobs:
           submodules: recursive
 
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
 

@@ -176,7 +176,7 @@ jobs:
         uses: docker/setup-buildx-action@v3
 
       - name: Cache Docker layers
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: /tmp/.buildx-cache
           key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}

@@ -231,7 +231,7 @@ jobs:
         fi
 
       - name: Restore dependencies cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}

@@ -282,7 +282,7 @@ jobs:
           submodules: recursive
 
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
 

@@ -290,7 +290,7 @@ jobs:
         run: corepack enable
 
       - name: Restore dependencies cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}

.github/workflows/platform-fullstack-ci.yml (vendored, 12 changes)

@@ -32,7 +32,7 @@ jobs:
         uses: actions/checkout@v4
 
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
 

@@ -44,7 +44,7 @@ jobs:
         run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
 
       - name: Cache dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ steps.cache-key.outputs.key }}

@@ -56,7 +56,7 @@ jobs:
         run: pnpm install --frozen-lockfile
 
   types:
-    runs-on: big-boi
+    runs-on: ubuntu-latest
     needs: setup
     strategy:
       fail-fast: false

@@ -68,7 +68,7 @@ jobs:
           submodules: recursive
 
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22.18.0"
 

@@ -85,10 +85,10 @@ jobs:
 
       - name: Run docker compose
         run: |
-          docker compose -f ../docker-compose.yml --profile local up -d deps_backend
+          docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
 
       - name: Restore dependencies cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ needs.setup.outputs.cache-key }}

autogpt_platform/autogpt_libs/poetry.lock (generated, 1877 changes) — file diff suppressed because it is too large.

@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
 [tool.poetry.dependencies]
 python = ">=3.10,<4.0"
 colorama = "^0.4.6"
-cryptography = "^46.0"
+cryptography = "^45.0"
 expiringdict = "^1.2.2"
-fastapi = "^0.128.0"
+fastapi = "^0.116.1"
-google-cloud-logging = "^3.13.0"
+google-cloud-logging = "^3.12.1"
-launchdarkly-server-sdk = "^9.15.0"
+launchdarkly-server-sdk = "^9.12.0"
-pydantic = "^2.12.5"
+pydantic = "^2.11.7"
-pydantic-settings = "^2.12.0"
+pydantic-settings = "^2.10.1"
-pyjwt = { version = "^2.11.0", extras = ["crypto"] }
+pyjwt = { version = "^2.10.1", extras = ["crypto"] }
-redis = "^7.1.1"
+redis = "^6.2.0"
-supabase = "^2.28.0"
+supabase = "^2.16.0"
-uvicorn = "^0.40.0"
+uvicorn = "^0.35.0"
 
 [tool.poetry.group.dev.dependencies]
-pyright = "^1.1.408"
+pyright = "^1.1.404"
-pytest = "^9.0.2"
+pytest = "^8.4.1"
-pytest-asyncio = "^1.3.0"
+pytest-asyncio = "^1.1.0"
-pytest-mock = "^3.15.1"
+pytest-mock = "^3.14.1"
-pytest-cov = "^7.0.0"
+pytest-cov = "^6.2.1"
-ruff = "^0.15.0"
+ruff = "^0.12.11"
 
 [build-system]
 requires = ["poetry-core"]

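The constraints above use Poetry's caret notation: ^X.Y.Z allows any release from X.Y.Z up to, but not including, the next major version. A minimal sketch of the equivalent range check, assuming the third-party packaging library (names here are illustrative, not part of the diff):

    # Sketch: what a caret constraint such as pydantic = "^2.11.7" expands to.
    # Poetry's ^X.Y.Z is equivalent to the range >=X.Y.Z,<(X+1).0.0.
    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    caret_2_11_7 = SpecifierSet(">=2.11.7,<3.0.0")

    print(Version("2.12.5") in caret_2_11_7)  # True  - still within the 2.x series
    print(Version("3.0.0") in caret_2_11_7)   # False - the major bump is excluded
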
@@ -93,12 +93,6 @@ class ChatConfig(BaseSettings):
         description="Name of the prompt in Langfuse to fetch",
     )
 
-    # Extended thinking configuration for Claude models
-    thinking_enabled: bool = Field(
-        default=True,
-        description="Enable adaptive thinking for Claude models via OpenRouter",
-    )
-
     @field_validator("api_key", mode="before")
     @classmethod
     def get_api_key(cls, v):

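For reference, the removed thinking_enabled flag follows the usual pydantic-settings pattern: a Field with a default and description on a BaseSettings subclass, overridable through the environment. A minimal standalone sketch of that pattern (class and field names are illustrative, not taken from the repo):

    # Sketch of the BaseSettings/Field pattern used by ChatConfig-style classes.
    from pydantic import Field
    from pydantic_settings import BaseSettings


    class ExampleChatConfig(BaseSettings):
        # Boolean feature flag with a default; pydantic-settings also reads it
        # from the THINKING_ENABLED environment variable when set.
        thinking_enabled: bool = Field(
            default=True,
            description="Enable adaptive thinking for Claude models",
        )


    config = ExampleChatConfig()
    print(config.thinking_enabled)
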
@@ -45,7 +45,10 @@ async def create_chat_session(
         successfulAgentRuns=SafeJson({}),
         successfulAgentSchedules=SafeJson({}),
     )
-    return await PrismaChatSession.prisma().create(data=data)
+    return await PrismaChatSession.prisma().create(
+        data=data,
+        include={"Messages": True},
+    )
 
 
 async def update_chat_session(

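The new include={"Messages": True} argument asks Prisma Client Python to load the session's related Messages records as part of the same create call, so the returned object already carries them. A rough sketch of the difference, assuming the same model as the diff and that relations not requested are left unloaded:

    # Sketch: creating a record with vs. without eagerly-loaded relations
    # (Prisma Client Python style, following the diff above).
    created = await PrismaChatSession.prisma().create(data=data)
    # created.Messages is not populated here, because the relation was not requested.

    created = await PrismaChatSession.prisma().create(
        data=data,
        include={"Messages": True},  # fetch the related Messages in the same call
    )
    # created.Messages now holds the (possibly empty) list of related records.
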
@@ -18,10 +18,6 @@ class ResponseType(str, Enum):
     START = "start"
     FINISH = "finish"
 
-    # Step lifecycle (one LLM API call within a message)
-    START_STEP = "start-step"
-    FINISH_STEP = "finish-step"
-
     # Text streaming
     TEXT_START = "text-start"
     TEXT_DELTA = "text-delta"

@@ -61,16 +57,6 @@ class StreamStart(StreamBaseResponse):
         description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
     )
 
-    def to_sse(self) -> str:
-        """Convert to SSE format, excluding non-protocol fields like taskId."""
-        import json
-
-        data: dict[str, Any] = {
-            "type": self.type.value,
-            "messageId": self.messageId,
-        }
-        return f"data: {json.dumps(data)}\n\n"
-
 
 class StreamFinish(StreamBaseResponse):
     """End of message/stream."""

@@ -78,26 +64,6 @@ class StreamFinish(StreamBaseResponse):
     type: ResponseType = ResponseType.FINISH
 
 
-class StreamStartStep(StreamBaseResponse):
-    """Start of a step (one LLM API call within a message).
-
-    The AI SDK uses this to add a step-start boundary to message.parts,
-    enabling visual separation between multiple LLM calls in a single message.
-    """
-
-    type: ResponseType = ResponseType.START_STEP
-
-
-class StreamFinishStep(StreamBaseResponse):
-    """End of a step (one LLM API call within a message).
-
-    The AI SDK uses this to reset activeTextParts and activeReasoningParts,
-    so the next LLM call in a tool-call continuation starts with clean state.
-    """
-
-    type: ResponseType = ResponseType.FINISH_STEP
-
-
 # ========== Text Streaming ==========
 
 

@@ -151,7 +117,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
     type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
     toolCallId: str = Field(..., description="Tool call ID this responds to")
     output: str | dict[str, Any] = Field(..., description="Tool execution output")
-    # Keep these for internal backend use
+    # Additional fields for internal use (not part of AI SDK spec but useful)
     toolName: str | None = Field(
         default=None, description="Name of the tool that was executed"
     )

@@ -159,17 +125,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
         default=True, description="Whether the tool execution succeeded"
     )
 
-    def to_sse(self) -> str:
-        """Convert to SSE format, excluding non-spec fields."""
-        import json
-
-        data = {
-            "type": self.type.value,
-            "toolCallId": self.toolCallId,
-            "output": self.output,
-        }
-        return f"data: {json.dumps(data)}\n\n"
-
 
 # ========== Other ==========
 

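Both removed to_sse overrides follow the same wire convention used elsewhere in this module: each chunk is serialized as a JSON object on a single "data:" line followed by a blank line, which is the Server-Sent Events framing the AI SDK consumes. A minimal standalone sketch of that framing (not the repo's base-class implementation):

    import json
    from typing import Any


    def to_sse_frame(payload: dict[str, Any]) -> str:
        """Render one Server-Sent Events frame: a 'data:' line plus a blank line."""
        return f"data: {json.dumps(payload)}\n\n"


    # e.g. the shape emitted for a stream start event
    print(to_sse_frame({"type": "start", "messageId": "msg-123"}))
    # -> data: {"type": "start", "messageId": "msg-123"}  (followed by a blank line)
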
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator
 from typing import Annotated
 
 from autogpt_libs import auth
-from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
+from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 

@@ -17,29 +17,7 @@ from . import stream_registry
 from .completion_handler import process_operation_failure, process_operation_success
 from .config import ChatConfig
 from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
-from .response_model import StreamFinish, StreamHeartbeat
+from .response_model import StreamFinish, StreamHeartbeat, StreamStart
-from .tools.models import (
-    AgentDetailsResponse,
-    AgentOutputResponse,
-    AgentPreviewResponse,
-    AgentSavedResponse,
-    AgentsFoundResponse,
-    BlockListResponse,
-    BlockOutputResponse,
-    ClarificationNeededResponse,
-    DocPageResponse,
-    DocSearchResultsResponse,
-    ErrorResponse,
-    ExecutionStartedResponse,
-    InputValidationErrorResponse,
-    NeedLoginResponse,
-    NoResultsResponse,
-    OperationInProgressResponse,
-    OperationPendingResponse,
-    OperationStartedResponse,
-    SetupRequirementsResponse,
-    UnderstandingUpdatedResponse,
-)
 
 config = ChatConfig()
 

@@ -288,36 +266,12 @@ async def stream_chat_post(
 
     """
     import asyncio
-    import time
 
-    stream_start_time = time.perf_counter()
-    log_meta = {"component": "ChatStream", "session_id": session_id}
-    if user_id:
-        log_meta["user_id"] = user_id
-
-    logger.info(
-        f"[TIMING] stream_chat_post STARTED, session={session_id}, "
-        f"user={user_id}, message_len={len(request.message)}",
-        extra={"json_fields": log_meta},
-    )
-
     session = await _validate_and_get_session(session_id, user_id)
-    logger.info(
-        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
-        extra={
-            "json_fields": {
-                **log_meta,
-                "duration_ms": (time.perf_counter() - stream_start_time) * 1000,
-            }
-        },
-    )
 
     # Create a task in the stream registry for reconnection support
     task_id = str(uuid_module.uuid4())
     operation_id = str(uuid_module.uuid4())
-    log_meta["task_id"] = task_id
-
-    task_create_start = time.perf_counter()
     await stream_registry.create_task(
         task_id=task_id,
         session_id=session_id,

@@ -326,28 +280,14 @@
         tool_name="chat",
         operation_id=operation_id,
     )
-    logger.info(
-        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
-        extra={
-            "json_fields": {
-                **log_meta,
-                "duration_ms": (time.perf_counter() - task_create_start) * 1000,
-            }
-        },
-    )
 
     # Background task that runs the AI generation independently of SSE connection
     async def run_ai_generation():
-        import time as time_module
-
-        gen_start_time = time_module.perf_counter()
-        logger.info(
-            f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
-            extra={"json_fields": log_meta},
-        )
-        first_chunk_time, ttfc = None, None
-        chunk_count = 0
         try:
+            # Emit a start event with task_id for reconnection
+            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
+            await stream_registry.publish_chunk(task_id, start_chunk)
+
            async for chunk in chat_service.stream_chat_completion(
                session_id,
                request.message,

@@ -355,79 +295,25 @@
                user_id=user_id,
                session=session,  # Pass pre-fetched session to avoid double-fetch
                context=request.context,
-                _task_id=task_id,  # Pass task_id so service emits start with taskId for reconnection
            ):
-                chunk_count += 1
-                if first_chunk_time is None:
-                    first_chunk_time = time_module.perf_counter()
-                    ttfc = first_chunk_time - gen_start_time
-                    logger.info(
-                        f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
-                        extra={
-                            "json_fields": {
-                                **log_meta,
-                                "chunk_type": type(chunk).__name__,
-                                "time_to_first_chunk_ms": ttfc * 1000,
-                            }
-                        },
-                    )
                # Write to Redis (subscribers will receive via XREAD)
                await stream_registry.publish_chunk(task_id, chunk)
 
-            gen_end_time = time_module.perf_counter()
-            total_time = (gen_end_time - gen_start_time) * 1000
-            logger.info(
-                f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
-                f"task={task_id}, session={session_id}, "
-                f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "total_time_ms": total_time,
-                        "time_to_first_chunk_ms": (
-                            ttfc * 1000 if ttfc is not None else None
-                        ),
-                        "n_chunks": chunk_count,
-                    }
-                },
-            )
+            # Mark task as completed
            await stream_registry.mark_task_completed(task_id, "completed")
        except Exception as e:
-            elapsed = time_module.perf_counter() - gen_start_time
            logger.error(
-                f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "elapsed_ms": elapsed * 1000,
-                        "error": str(e),
-                    }
-                },
+                f"Error in background AI generation for session {session_id}: {e}"
            )
            await stream_registry.mark_task_completed(task_id, "failed")
 
    # Start the AI generation in a background task
    bg_task = asyncio.create_task(run_ai_generation())
    await stream_registry.set_task_asyncio_task(task_id, bg_task)
-    setup_time = (time.perf_counter() - stream_start_time) * 1000
-    logger.info(
-        f"[TIMING] Background task started, setup={setup_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
-    )
 
    # SSE endpoint that subscribes to the task's stream
    async def event_generator() -> AsyncGenerator[str, None]:
-        import time as time_module
-
-        event_gen_start = time_module.perf_counter()
-        logger.info(
-            f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
-            f"user={user_id}",
-            extra={"json_fields": log_meta},
-        )
        subscriber_queue = None
-        first_chunk_yielded = False
-        chunks_yielded = 0
        try:
            # Subscribe to the task stream (this replays existing messages + live updates)
            subscriber_queue = await stream_registry.subscribe_to_task(

@@ -442,70 +328,22 @@
                return
 
            # Read from the subscriber queue and yield to SSE
-            logger.info(
-                "[TIMING] Starting to read from subscriber_queue",
-                extra={"json_fields": log_meta},
-            )
            while True:
                try:
                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
-                    chunks_yielded += 1
-
-                    if not first_chunk_yielded:
-                        first_chunk_yielded = True
-                        elapsed = time_module.perf_counter() - event_gen_start
-                        logger.info(
-                            f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
-                            f"type={type(chunk).__name__}",
-                            extra={
-                                "json_fields": {
-                                    **log_meta,
-                                    "chunk_type": type(chunk).__name__,
-                                    "elapsed_ms": elapsed * 1000,
-                                }
-                            },
-                        )
-
                    yield chunk.to_sse()
 
                    # Check for finish signal
                    if isinstance(chunk, StreamFinish):
-                        total_time = time_module.perf_counter() - event_gen_start
-                        logger.info(
-                            f"[TIMING] StreamFinish received in {total_time:.2f}s; "
-                            f"n_chunks={chunks_yielded}",
-                            extra={
-                                "json_fields": {
-                                    **log_meta,
-                                    "chunks_yielded": chunks_yielded,
-                                    "total_time_ms": total_time * 1000,
-                                }
-                            },
-                        )
                        break
                except asyncio.TimeoutError:
+                    # Send heartbeat to keep connection alive
                    yield StreamHeartbeat().to_sse()
 
        except GeneratorExit:
-            logger.info(
-                f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "chunks_yielded": chunks_yielded,
-                        "reason": "client_disconnect",
-                    }
-                },
-            )
            pass  # Client disconnected - background task continues
        except Exception as e:
-            elapsed = (time_module.perf_counter() - event_gen_start) * 1000
-            logger.error(
-                f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
-                extra={
-                    "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
-                },
-            )
+            logger.error(f"Error in SSE stream for task {task_id}: {e}")
        finally:
            # Unsubscribe when client disconnects or stream ends to prevent resource leak
            if subscriber_queue is not None:

@@ -519,18 +357,6 @@
                        exc_info=True,
                    )
            # AI SDK protocol termination - always yield even if unsubscribe fails
-            total_time = time_module.perf_counter() - event_gen_start
-            logger.info(
-                f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
-                f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "total_time_ms": total_time * 1000,
-                        "chunks_yielded": chunks_yielded,
-                    }
-                },
-            )
            yield "data: [DONE]\n\n"
 
    return StreamingResponse(

@@ -548,90 +374,63 @@
 @router.get(
     "/sessions/{session_id}/stream",
 )
-async def resume_session_stream(
+async def stream_chat_get(
     session_id: str,
+    message: Annotated[str, Query(min_length=1, max_length=10000)],
     user_id: str | None = Depends(auth.get_user_id),
+    is_user_message: bool = Query(default=True),
 ):
     """
-    Resume an active stream for a session.
+    Stream chat responses for a session (GET - legacy endpoint).
 
-    Called by the AI SDK's ``useChat(resume: true)`` on page load.
-    Checks for an active (in-progress) task on the session and either replays
-    the full SSE stream or returns 204 No Content if nothing is running.
+    Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
+    - Text fragments as they are generated
+    - Tool call UI elements (if invoked)
+    - Tool execution results
 
     Args:
-        session_id: The chat session identifier.
+        session_id: The chat session identifier to associate with the streamed messages.
+        message: The user's new message to process.
        user_id: Optional authenticated user ID.
+        is_user_message: Whether the message is a user message.
    Returns:
-        StreamingResponse (SSE) when an active stream exists,
-        or 204 No Content when there is nothing to resume.
+        StreamingResponse: SSE-formatted response chunks.
    """
-    import asyncio
+    session = await _validate_and_get_session(session_id, user_id)
 
-    active_task, _last_id = await stream_registry.get_active_task_for_session(
-        session_id, user_id
-    )
-
-    if not active_task:
-        return Response(status_code=204)
-
-    subscriber_queue = await stream_registry.subscribe_to_task(
-        task_id=active_task.task_id,
-        user_id=user_id,
-        last_message_id="0-0",  # Full replay so useChat rebuilds the message
-    )
-
-    if subscriber_queue is None:
-        return Response(status_code=204)
-
    async def event_generator() -> AsyncGenerator[str, None]:
        chunk_count = 0
        first_chunk_type: str | None = None
-        try:
-            while True:
-                try:
-                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
-                    if chunk_count < 3:
-                        logger.info(
-                            "Resume stream chunk",
-                            extra={
-                                "session_id": session_id,
-                                "chunk_type": str(chunk.type),
-                            },
-                        )
-                    if not first_chunk_type:
-                        first_chunk_type = str(chunk.type)
-                    chunk_count += 1
-                    yield chunk.to_sse()
-
-                    if isinstance(chunk, StreamFinish):
-                        break
-                except asyncio.TimeoutError:
-                    yield StreamHeartbeat().to_sse()
-        except GeneratorExit:
-            pass
-        except Exception as e:
-            logger.error(f"Error in resume stream for session {session_id}: {e}")
-        finally:
-            try:
-                await stream_registry.unsubscribe_from_task(
-                    active_task.task_id, subscriber_queue
-                )
-            except Exception as unsub_err:
-                logger.error(
-                    f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
-                    exc_info=True,
-                )
-            logger.info(
-                "Resume stream completed",
-                extra={
-                    "session_id": session_id,
-                    "n_chunks": chunk_count,
-                    "first_chunk_type": first_chunk_type,
-                },
-            )
+        async for chunk in chat_service.stream_chat_completion(
+            session_id,
+            message,
+            is_user_message=is_user_message,
+            user_id=user_id,
+            session=session,  # Pass pre-fetched session to avoid double-fetch
+        ):
+            if chunk_count < 3:
+                logger.info(
+                    "Chat stream chunk",
+                    extra={
+                        "session_id": session_id,
+                        "chunk_type": str(chunk.type),
+                    },
+                )
+            if not first_chunk_type:
+                first_chunk_type = str(chunk.type)
+            chunk_count += 1
+            yield chunk.to_sse()
+        logger.info(
+            "Chat stream completed",
+            extra={
+                "session_id": session_id,
+                "chunk_count": chunk_count,
+                "first_chunk_type": first_chunk_type,
+            },
+        )
+        # AI SDK protocol termination
        yield "data: [DONE]\n\n"
 
    return StreamingResponse(
        event_generator(),

@@ -639,8 +438,8 @@ async def resume_session_stream(
         headers={
             "Cache-Control": "no-cache",
             "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",
-            "x-vercel-ai-ui-message-stream": "v1",
+            "X-Accel-Buffering": "no",  # Disable nginx buffering
+            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
         },
     )
 

@@ -952,42 +751,3 @@ async def health_check() -> dict:
         "service": "chat",
         "version": "0.1.0",
     }
-
-
-# ========== Schema Export (for OpenAPI / Orval codegen) ==========
-
-ToolResponseUnion = (
-    AgentsFoundResponse
-    | NoResultsResponse
-    | AgentDetailsResponse
-    | SetupRequirementsResponse
-    | ExecutionStartedResponse
-    | NeedLoginResponse
-    | ErrorResponse
-    | InputValidationErrorResponse
-    | AgentOutputResponse
-    | UnderstandingUpdatedResponse
-    | AgentPreviewResponse
-    | AgentSavedResponse
-    | ClarificationNeededResponse
-    | BlockListResponse
-    | BlockOutputResponse
-    | DocSearchResultsResponse
-    | DocPageResponse
-    | OperationStartedResponse
-    | OperationPendingResponse
-    | OperationInProgressResponse
-)
-
-
-@router.get(
-    "/schema/tool-responses",
-    response_model=ToolResponseUnion,
-    include_in_schema=True,
-    summary="[Dummy] Tool response type export for codegen",
-    description="This endpoint is not meant to be called. It exists solely to "
-    "expose tool response models in the OpenAPI schema for frontend codegen.",
-)
-async def _tool_response_schema() -> ToolResponseUnion:  # type: ignore[return]
-    """Never called at runtime. Exists only so Orval generates TS types."""
-    raise HTTPException(status_code=501, detail="Schema-only endpoint")

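On the wire, the streaming endpoints above emit "data:" frames, periodic heartbeats, and a final "data: [DONE]" terminator. A rough client-side sketch of reading such a stream, assuming the httpx library (URL and payload handling are placeholders, not taken from the repo):

    import json

    import httpx


    async def read_chat_stream(url: str) -> None:
        """Consume an SSE chat stream until the [DONE] terminator."""
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("GET", url) as response:
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue  # skip blank separator lines between frames
                    body = line[len("data: "):]
                    if body == "[DONE]":
                        break  # AI SDK protocol termination
                    chunk = json.loads(body)
                    print(chunk.get("type"), chunk)
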
@@ -52,10 +52,8 @@ from .response_model import (
|
|||||||
StreamBaseResponse,
|
StreamBaseResponse,
|
||||||
StreamError,
|
StreamError,
|
||||||
StreamFinish,
|
StreamFinish,
|
||||||
StreamFinishStep,
|
|
||||||
StreamHeartbeat,
|
StreamHeartbeat,
|
||||||
StreamStart,
|
StreamStart,
|
||||||
StreamStartStep,
|
|
||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamTextEnd,
|
StreamTextEnd,
|
||||||
StreamTextStart,
|
StreamTextStart,
|
||||||
@@ -353,10 +351,6 @@ async def stream_chat_completion(
|
|||||||
retry_count: int = 0,
|
retry_count: int = 0,
|
||||||
session: ChatSession | None = None,
|
session: ChatSession | None = None,
|
||||||
context: dict[str, str] | None = None, # {url: str, content: str}
|
context: dict[str, str] | None = None, # {url: str, content: str}
|
||||||
_continuation_message_id: (
|
|
||||||
str | None
|
|
||||||
) = None, # Internal: reuse message ID for tool call continuations
|
|
||||||
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
|
|
||||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
"""Main entry point for streaming chat completions with database handling.
|
"""Main entry point for streaming chat completions with database handling.
|
||||||
|
|
||||||
@@ -377,45 +371,21 @@ async def stream_chat_completion(
|
|||||||
ValueError: If max_context_messages is exceeded
|
ValueError: If max_context_messages is exceeded
|
||||||
|
|
||||||
"""
|
"""
|
||||||
completion_start = time.monotonic()
|
|
||||||
|
|
||||||
# Build log metadata for structured logging
|
|
||||||
log_meta = {"component": "ChatService", "session_id": session_id}
|
|
||||||
if user_id:
|
|
||||||
log_meta["user_id"] = user_id
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
||||||
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"message_len": len(message) if message else 0,
|
|
||||||
"is_user_message": is_user_message,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only fetch from Redis if session not provided (initial call)
|
# Only fetch from Redis if session not provided (initial call)
|
||||||
if session is None:
|
if session is None:
|
||||||
fetch_start = time.monotonic()
|
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
fetch_time = (time.monotonic() - fetch_start) * 1000
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
|
||||||
f"n_messages={len(session.messages) if session else 0}",
|
f"message_count={len(session.messages) if session else 0}"
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": fetch_time,
|
|
||||||
"n_messages": len(session.messages) if session else 0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
f"Using provided session object: {session.session_id}, "
|
||||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
f"message_count={len(session.messages)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
@@ -436,25 +406,17 @@ async def stream_chat_completion(
 
     # Track user message in PostHog
     if is_user_message:
-        posthog_start = time.monotonic()
         track_user_message(
             user_id=user_id,
             session_id=session_id,
            message_length=len(message),
         )
-        posthog_time = (time.monotonic() - posthog_start) * 1000
-        logger.info(
-            f"[TIMING] track_user_message took {posthog_time:.1f}ms",
-            extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
-        )
 
-    upsert_start = time.monotonic()
-    session = await upsert_chat_session(session)
-    upsert_time = (time.monotonic() - upsert_start) * 1000
     logger.info(
-        f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
+        f"Upserting session: {session.session_id} with user id {session.user_id}, "
+        f"message_count={len(session.messages)}"
     )
+    session = await upsert_chat_session(session)
     assert session, "Session not found"
 
     # Generate title for new sessions on first user message (non-blocking)
@@ -492,13 +454,7 @@ async def stream_chat_completion(
         asyncio.create_task(_update_title())
 
     # Build system prompt with business understanding
-    prompt_start = time.monotonic()
     system_prompt, understanding = await _build_system_prompt(user_id)
-    prompt_time = (time.monotonic() - prompt_start) * 1000
-    logger.info(
-        f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
-    )
 
     # Initialize variables for streaming
     assistant_response = ChatMessage(
@@ -523,27 +479,13 @@ async def stream_chat_completion(
|
|||||||
# Generate unique IDs for AI SDK protocol
|
# Generate unique IDs for AI SDK protocol
|
||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
|
|
||||||
is_continuation = _continuation_message_id is not None
|
message_id = str(uuid_module.uuid4())
|
||||||
message_id = _continuation_message_id or str(uuid_module.uuid4())
|
|
||||||
text_block_id = str(uuid_module.uuid4())
|
text_block_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
# Only yield message start for the initial call, not for continuations.
|
# Yield message start
|
||||||
setup_time = (time.monotonic() - completion_start) * 1000
|
yield StreamStart(messageId=message_id)
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
|
||||||
)
|
|
||||||
if not is_continuation:
|
|
||||||
yield StreamStart(messageId=message_id, taskId=_task_id)
|
|
||||||
|
|
||||||
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
|
|
||||||
yield StreamStartStep()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(
|
|
||||||
"[TIMING] Calling _stream_chat_chunks",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
async for chunk in _stream_chat_chunks(
|
async for chunk in _stream_chat_chunks(
|
||||||
session=session,
|
session=session,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -643,10 +585,6 @@ async def stream_chat_completion(
                 )
                 yield chunk
             elif isinstance(chunk, StreamFinish):
-                if has_done_tool_call:
-                    # Tool calls happened — close the step but don't send message-level finish.
-                    # The continuation will open a new step, and finish will come at the end.
-                    yield StreamFinishStep()
                 if not has_done_tool_call:
                     # Emit text-end before finish if we received text but haven't closed it
                     if has_received_text and not text_streaming_ended:
@@ -678,8 +616,6 @@ async def stream_chat_completion(
                         has_saved_assistant_message = True
 
                 has_yielded_end = True
-                # Emit finish-step before finish (resets AI SDK text/reasoning state)
-                yield StreamFinishStep()
                 yield chunk
             elif isinstance(chunk, StreamError):
                 has_yielded_error = True
@@ -729,10 +665,6 @@ async def stream_chat_completion(
                 logger.info(
                     f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
                 )
-                # Close the current step before retrying so the recursive call's
-                # StreamStartStep doesn't produce unbalanced step events.
-                if not has_yielded_end:
-                    yield StreamFinishStep()
                 should_retry = True
             else:
                 # Non-retryable error or max retries exceeded
@@ -768,7 +700,6 @@ async def stream_chat_completion(
                 error_response = StreamError(errorText=error_message)
                 yield error_response
                 if not has_yielded_end:
-                    yield StreamFinishStep()
                     yield StreamFinish()
                 return
 
@@ -783,8 +714,6 @@ async def stream_chat_completion(
                 retry_count=retry_count + 1,
                 session=session,
                 context=context,
-                _continuation_message_id=message_id,  # Reuse message ID since start was already sent
-                _task_id=_task_id,
             ):
                 yield chunk
             return  # Exit after retry to avoid double-saving in finally block
@@ -854,8 +783,6 @@ async def stream_chat_completion(
                 session=session,  # Pass session object to avoid Redis refetch
                 context=context,
                 tool_call_response=str(tool_response_messages),
-                _continuation_message_id=message_id,  # Reuse message ID to avoid duplicates
-                _task_id=_task_id,
             ):
                 yield chunk
 
@@ -966,21 +893,9 @@ async def _stream_chat_chunks(
|
|||||||
SSE formatted JSON response objects
|
SSE formatted JSON response objects
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import time as time_module
|
|
||||||
|
|
||||||
stream_chunks_start = time_module.perf_counter()
|
|
||||||
model = config.model
|
model = config.model
|
||||||
|
|
||||||
# Build log metadata for structured logging
|
logger.info("Starting pure chat stream")
|
||||||
log_meta = {"component": "ChatService", "session_id": session.session_id}
|
|
||||||
if session.user_id:
|
|
||||||
log_meta["user_id"] = session.user_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
|
|
||||||
f"user={session.user_id}, n_messages={len(session.messages)}",
|
|
||||||
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = session.to_openai_messages()
|
messages = session.to_openai_messages()
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
@@ -991,18 +906,12 @@ async def _stream_chat_chunks(
|
|||||||
messages = [system_message] + messages
|
messages = [system_message] + messages
|
||||||
|
|
||||||
# Apply context window management
|
# Apply context window management
|
||||||
context_start = time_module.perf_counter()
|
|
||||||
context_result = await _manage_context_window(
|
context_result = await _manage_context_window(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=model,
|
model=model,
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
base_url=config.base_url,
|
base_url=config.base_url,
|
||||||
)
|
)
|
||||||
context_time = (time_module.perf_counter() - context_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
if context_result.error:
|
if context_result.error:
|
||||||
if "System prompt dropped" in context_result.error:
|
if "System prompt dropped" in context_result.error:
|
||||||
@@ -1037,19 +946,9 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
while retry_count <= MAX_RETRIES:
|
while retry_count <= MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
|
|
||||||
retry_info = (
|
|
||||||
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
|
f"Creating OpenAI chat completion stream..."
|
||||||
extra={
|
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"retry_count": retry_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build extra_body for OpenRouter tracing and PostHog analytics
|
# Build extra_body for OpenRouter tracing and PostHog analytics
|
||||||
@@ -1066,11 +965,6 @@ async def _stream_chat_chunks(
|
|||||||
:128
|
:128
|
||||||
] # OpenRouter limit
|
] # OpenRouter limit
|
||||||
|
|
||||||
# Enable adaptive thinking for Anthropic models via OpenRouter
|
|
||||||
if config.thinking_enabled and "anthropic" in model.lower():
|
|
||||||
extra_body["reasoning"] = {"enabled": True}
|
|
||||||
|
|
||||||
api_call_start = time_module.perf_counter()
|
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||||
@@ -1080,11 +974,6 @@ async def _stream_chat_chunks(
|
|||||||
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Variables to accumulate tool calls
|
# Variables to accumulate tool calls
|
||||||
tool_calls: list[dict[str, Any]] = []
|
tool_calls: list[dict[str, Any]] = []
|
||||||
@@ -1095,13 +984,10 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
# Track if we've started the text block
|
# Track if we've started the text block
|
||||||
text_started = False
|
text_started = False
|
||||||
first_content_chunk = True
|
|
||||||
chunk_count = 0
|
|
||||||
|
|
||||||
# Process the stream
|
# Process the stream
|
||||||
chunk: ChatCompletionChunk
|
chunk: ChatCompletionChunk
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
chunk_count += 1
|
|
||||||
if chunk.usage:
|
if chunk.usage:
|
||||||
yield StreamUsage(
|
yield StreamUsage(
|
||||||
promptTokens=chunk.usage.prompt_tokens,
|
promptTokens=chunk.usage.prompt_tokens,
|
||||||
@@ -1124,23 +1010,6 @@ async def _stream_chat_chunks(
|
|||||||
if not text_started and text_block_id:
|
if not text_started and text_block_id:
|
||||||
yield StreamTextStart(id=text_block_id)
|
yield StreamTextStart(id=text_block_id)
|
||||||
text_started = True
|
text_started = True
|
||||||
# Log timing for first content chunk
|
|
||||||
if first_content_chunk:
|
|
||||||
first_content_chunk = False
|
|
||||||
ttfc = (
|
|
||||||
time_module.perf_counter() - api_call_start
|
|
||||||
) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
|
|
||||||
f"(since API call), n_chunks={chunk_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"time_to_first_chunk_ms": ttfc,
|
|
||||||
"n_chunks": chunk_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Stream the text delta
|
# Stream the text delta
|
||||||
text_response = StreamTextDelta(
|
text_response = StreamTextDelta(
|
||||||
id=text_block_id or "",
|
id=text_block_id or "",
|
||||||
@@ -1197,21 +1066,7 @@ async def _stream_chat_chunks(
|
|||||||
toolName=tool_calls[idx]["function"]["name"],
|
toolName=tool_calls[idx]["function"]["name"],
|
||||||
)
|
)
|
||||||
emitted_start_for_idx.add(idx)
|
emitted_start_for_idx.add(idx)
|
||||||
stream_duration = time_module.perf_counter() - api_call_start
|
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
|
|
||||||
f"duration={stream_duration:.2f}s, "
|
|
||||||
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"stream_duration_ms": stream_duration * 1000,
|
|
||||||
"finish_reason": finish_reason,
|
|
||||||
"n_chunks": chunk_count,
|
|
||||||
"n_tool_calls": len(tool_calls),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Yield all accumulated tool calls after the stream is complete
|
# Yield all accumulated tool calls after the stream is complete
|
||||||
# This ensures all tool call arguments have been fully received
|
# This ensures all tool call arguments have been fully received
|
||||||
@@ -1231,12 +1086,6 @@ async def _stream_chat_chunks(
|
|||||||
# Re-raise to trigger retry logic in the parent function
|
# Re-raise to trigger retry logic in the parent function
|
||||||
raise
|
raise
|
||||||
|
|
||||||
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
|
|
||||||
f"session={session.session_id}, user={session.user_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
|
||||||
)
|
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1716,7 +1565,6 @@ async def _execute_long_running_tool_with_streaming(
             task_id,
             StreamError(errorText=str(e)),
         )
-        await stream_registry.publish_chunk(task_id, StreamFinishStep())
         await stream_registry.publish_chunk(task_id, StreamFinish())
 
         await _update_pending_operation(
@@ -1833,10 +1681,6 @@ async def _generate_llm_continuation(
     if session_id:
         extra_body["session_id"] = session_id[:128]
 
-    # Enable adaptive thinking for Anthropic models via OpenRouter
-    if config.thinking_enabled and "anthropic" in config.model.lower():
-        extra_body["reasoning"] = {"enabled": True}
-
     retry_count = 0
     last_error: Exception | None = None
     response = None
@@ -1967,10 +1811,6 @@ async def _generate_llm_continuation_with_streaming(
     if session_id:
         extra_body["session_id"] = session_id[:128]
 
-    # Enable adaptive thinking for Anthropic models via OpenRouter
-    if config.thinking_enabled and "anthropic" in config.model.lower():
-        extra_body["reasoning"] = {"enabled": True}
-
     # Make streaming LLM call (no tools - just text response)
     from typing import cast
 
@@ -1982,7 +1822,6 @@ async def _generate_llm_continuation_with_streaming(
 
     # Publish start event
     await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
-    await stream_registry.publish_chunk(task_id, StreamStartStep())
     await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
 
     # Stream the response
@@ -2006,7 +1845,6 @@ async def _generate_llm_continuation_with_streaming(
 
     # Publish end events
     await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
-    await stream_registry.publish_chunk(task_id, StreamFinishStep())
 
     if assistant_content:
         # Reload session from DB to avoid race condition with user messages
@@ -2048,5 +1886,4 @@ async def _generate_llm_continuation_with_streaming(
             task_id,
             StreamError(errorText=f"Failed to generate response: {e}"),
         )
-        await stream_registry.publish_chunk(task_id, StreamFinishStep())
         await stream_registry.publish_chunk(task_id, StreamFinish())
@@ -104,24 +104,6 @@ async def create_task(
     Returns:
         The created ActiveTask instance (metadata only)
     """
-    import time
-
-    start_time = time.perf_counter()
-
-    # Build log metadata for structured logging
-    log_meta = {
-        "component": "StreamRegistry",
-        "task_id": task_id,
-        "session_id": session_id,
-    }
-    if user_id:
-        log_meta["user_id"] = user_id
-
-    logger.info(
-        f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
-        extra={"json_fields": log_meta},
-    )
-
     task = ActiveTask(
         task_id=task_id,
         session_id=session_id,
@@ -132,18 +114,10 @@ async def create_task(
     )
 
     # Store metadata in Redis
-    redis_start = time.perf_counter()
     redis = await get_redis_async()
-    redis_time = (time.perf_counter() - redis_start) * 1000
-    logger.info(
-        f"[TIMING] get_redis_async took {redis_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
-    )
-
     meta_key = _get_task_meta_key(task_id)
     op_key = _get_operation_mapping_key(operation_id)
 
-    hset_start = time.perf_counter()
     await redis.hset(  # type: ignore[misc]
         meta_key,
         mapping={
@@ -157,22 +131,12 @@ async def create_task(
             "created_at": task.created_at.isoformat(),
         },
     )
-    hset_time = (time.perf_counter() - hset_start) * 1000
-    logger.info(
-        f"[TIMING] redis.hset took {hset_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
-    )
-
     await redis.expire(meta_key, config.stream_ttl)
 
     # Create operation_id -> task_id mapping for webhook lookups
     await redis.set(op_key, task_id, ex=config.stream_ttl)
 
-    total_time = (time.perf_counter() - start_time) * 1000
-    logger.info(
-        f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
-        extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
-    )
+    logger.debug(f"Created task {task_id} for session {session_id}")
 
     return task
 
@@ -192,60 +156,26 @@ async def publish_chunk(
|
|||||||
Returns:
|
Returns:
|
||||||
The Redis Stream message ID
|
The Redis Stream message ID
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
chunk_type = type(chunk).__name__
|
|
||||||
chunk_json = chunk.model_dump_json()
|
chunk_json = chunk.model_dump_json()
|
||||||
message_id = "0-0"
|
message_id = "0-0"
|
||||||
|
|
||||||
# Build log metadata
|
|
||||||
log_meta = {
|
|
||||||
"component": "StreamRegistry",
|
|
||||||
"task_id": task_id,
|
|
||||||
"chunk_type": chunk_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Write to Redis Stream for persistence and real-time delivery
|
# Write to Redis Stream for persistence and real-time delivery
|
||||||
xadd_start = time.perf_counter()
|
|
||||||
raw_id = await redis.xadd(
|
raw_id = await redis.xadd(
|
||||||
stream_key,
|
stream_key,
|
||||||
{"data": chunk_json},
|
{"data": chunk_json},
|
||||||
maxlen=config.stream_max_length,
|
maxlen=config.stream_max_length,
|
||||||
)
|
)
|
||||||
xadd_time = (time.perf_counter() - xadd_start) * 1000
|
|
||||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||||
|
|
||||||
# Set TTL on stream to match task metadata TTL
|
# Set TTL on stream to match task metadata TTL
|
||||||
await redis.expire(stream_key, config.stream_ttl)
|
await redis.expire(stream_key, config.stream_ttl)
|
||||||
|
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
# Only log timing for significant chunks or slow operations
|
|
||||||
if (
|
|
||||||
chunk_type
|
|
||||||
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
|
|
||||||
or total_time > 50
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"xadd_time_ms": xadd_time,
|
|
||||||
"message_id": message_id,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
|
f"Failed to publish chunk for task {task_id}: {e}",
|
||||||
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -270,61 +200,24 @@ async def subscribe_to_task(
|
|||||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||||
or user doesn't have access
|
or user doesn't have access
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
# Build log metadata
|
|
||||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
|
||||||
if user_id:
|
|
||||||
log_meta["user_id"] = user_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
|
|
||||||
)
|
|
||||||
|
|
||||||
redis_start = time.perf_counter()
|
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||||
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
|
||||||
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not meta:
|
if not meta:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
logger.debug(f"Task {task_id} not found in Redis")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"reason": "task_not_found",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||||
task_status = meta.get("status", "")
|
task_status = meta.get("status", "")
|
||||||
task_user_id = meta.get("user_id", "") or None
|
task_user_id = meta.get("user_id", "") or None
|
||||||
log_meta["session_id"] = meta.get("session_id", "")
|
|
||||||
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
# Validate ownership - if task has an owner, requester must match
|
||||||
if task_user_id:
|
if task_user_id:
|
||||||
if user_id != task_user_id:
|
if user_id != task_user_id:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
|
f"User {user_id} denied access to task {task_id} "
|
||||||
extra={
|
f"owned by {task_user_id}"
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"task_owner": task_user_id,
|
|
||||||
"reason": "access_denied",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -332,19 +225,7 @@ async def subscribe_to_task(
|
|||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Step 1: Replay messages from Redis Stream
|
# Step 1: Replay messages from Redis Stream
|
||||||
xread_start = time.perf_counter()
|
|
||||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"duration_ms": xread_time,
|
|
||||||
"task_status": task_status,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
replayed_count = 0
|
replayed_count = 0
|
||||||
replay_last_id = last_message_id
|
replay_last_id = last_message_id
|
||||||
@@ -363,48 +244,19 @@ async def subscribe_to_task(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to replay message: {e}")
|
logger.warning(f"Failed to replay message: {e}")
|
||||||
|
|
||||||
logger.info(
|
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
||||||
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"n_messages_replayed": replayed_count,
|
|
||||||
"replay_last_id": replay_last_id,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 2: If task is still running, start stream listener for live updates
|
# Step 2: If task is still running, start stream listener for live updates
|
||||||
if task_status == "running":
|
if task_status == "running":
|
||||||
logger.info(
|
|
||||||
"[TIMING] Task still running, starting _stream_listener",
|
|
||||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
|
||||||
)
|
|
||||||
listener_task = asyncio.create_task(
|
listener_task = asyncio.create_task(
|
||||||
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
|
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
||||||
)
|
)
|
||||||
# Track listener task for cleanup on unsubscribe
|
# Track listener task for cleanup on unsubscribe
|
||||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||||
else:
|
else:
|
||||||
# Task is completed/failed - add finish marker
|
# Task is completed/failed - add finish marker
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Task already {task_status}, adding StreamFinish",
|
|
||||||
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
|
||||||
)
|
|
||||||
await subscriber_queue.put(StreamFinish())
|
await subscriber_queue.put(StreamFinish())
|
||||||
|
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
|
|
||||||
f"n_messages_replayed={replayed_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"n_messages_replayed": replayed_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return subscriber_queue
|
return subscriber_queue
|
||||||
|
|
||||||
|
|
||||||
@@ -412,7 +264,6 @@ async def _stream_listener(
     task_id: str,
     subscriber_queue: asyncio.Queue[StreamBaseResponse],
     last_replayed_id: str,
-    log_meta: dict | None = None,
 ) -> None:
     """Listen to Redis Stream for new messages using blocking XREAD.
 
@@ -423,27 +274,10 @@ async def _stream_listener(
|
|||||||
task_id: Task ID to listen for
|
task_id: Task ID to listen for
|
||||||
subscriber_queue: Queue to deliver messages to
|
subscriber_queue: Queue to deliver messages to
|
||||||
last_replayed_id: Last message ID from replay (continue from here)
|
last_replayed_id: Last message ID from replay (continue from here)
|
||||||
log_meta: Structured logging metadata
|
|
||||||
"""
|
"""
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
# Use provided log_meta or build minimal one
|
|
||||||
if log_meta is None:
|
|
||||||
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
|
|
||||||
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
|
|
||||||
)
|
|
||||||
|
|
||||||
queue_id = id(subscriber_queue)
|
queue_id = id(subscriber_queue)
|
||||||
# Track the last successfully delivered message ID for recovery hints
|
# Track the last successfully delivered message ID for recovery hints
|
||||||
last_delivered_id = last_replayed_id
|
last_delivered_id = last_replayed_id
|
||||||
messages_delivered = 0
|
|
||||||
first_message_time = None
|
|
||||||
xread_count = 0
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
@@ -453,39 +287,9 @@ async def _stream_listener(
|
|||||||
while True:
|
while True:
|
||||||
# Block for up to 30 seconds waiting for new messages
|
# Block for up to 30 seconds waiting for new messages
|
||||||
# This allows periodic checking if task is still running
|
# This allows periodic checking if task is still running
|
||||||
xread_start = time.perf_counter()
|
|
||||||
xread_count += 1
|
|
||||||
messages = await redis.xread(
|
messages = await redis.xread(
|
||||||
{stream_key: current_id}, block=30000, count=100
|
{stream_key: current_id}, block=30000, count=100
|
||||||
)
|
)
|
||||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
|
||||||
|
|
||||||
if messages:
|
|
||||||
msg_count = sum(len(msgs) for _, msgs in messages)
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"xread_count": xread_count,
|
|
||||||
"n_messages": msg_count,
|
|
||||||
"duration_ms": xread_time,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
elif xread_time > 1000:
|
|
||||||
# Only log timeouts (30s blocking)
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"xread_count": xread_count,
|
|
||||||
"duration_ms": xread_time,
|
|
||||||
"reason": "timeout",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not messages:
|
if not messages:
|
||||||
# Timeout - check if task is still running
|
# Timeout - check if task is still running
|
||||||
@@ -522,30 +326,10 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
# Update last delivered ID on successful delivery
|
# Update last delivered ID on successful delivery
|
||||||
last_delivered_id = current_id
|
last_delivered_id = current_id
|
||||||
messages_delivered += 1
|
|
||||||
if first_message_time is None:
|
|
||||||
first_message_time = time.perf_counter()
|
|
||||||
elapsed = (first_message_time - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"chunk_type": type(chunk).__name__,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
f"Subscriber queue full for task {task_id}, "
|
||||||
extra={
|
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"timeout_s": QUEUE_PUT_TIMEOUT,
|
|
||||||
"reason": "queue_full",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
# Send overflow error with recovery info
|
# Send overflow error with recovery info
|
||||||
try:
|
try:
|
||||||
@@ -567,44 +351,15 @@ async def _stream_listener(
|
|||||||
|
|
||||||
# Stop listening on finish
|
# Stop listening on finish
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"messages_delivered": messages_delivered,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(f"Error processing stream message: {e}")
|
||||||
f"Error processing stream message: {e}",
|
|
||||||
extra={"json_fields": {**log_meta, "error": str(e)}},
|
|
||||||
)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
logger.debug(f"Stream listener cancelled for task {task_id}")
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed,
|
|
||||||
"messages_delivered": messages_delivered,
|
|
||||||
"reason": "cancelled",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
raise # Re-raise to propagate cancellation
|
raise # Re-raise to propagate cancellation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed = (time.perf_counter() - start_time) * 1000
|
logger.error(f"Stream listener error for task {task_id}: {e}")
|
||||||
logger.error(
|
|
||||||
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
|
|
||||||
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
|
||||||
)
|
|
||||||
# On error, send finish to unblock subscriber
|
# On error, send finish to unblock subscriber
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
@@ -613,24 +368,10 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Could not deliver finish event after error",
|
f"Could not deliver finish event for task {task_id} after error"
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Clean up listener task mapping on exit
|
# Clean up listener task mapping on exit
|
||||||
total_time = (time.perf_counter() - start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
|
|
||||||
f"delivered={messages_delivered}, xread_count={xread_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"messages_delivered": messages_delivered,
|
|
||||||
"xread_count": xread_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
_listener_tasks.pop(queue_id, None)
|
_listener_tasks.pop(queue_id, None)
|
||||||
|
|
||||||
|
|
||||||
@@ -857,10 +598,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
         ResponseType,
         StreamError,
         StreamFinish,
-        StreamFinishStep,
         StreamHeartbeat,
         StreamStart,
-        StreamStartStep,
         StreamTextDelta,
         StreamTextEnd,
         StreamTextStart,
@@ -874,8 +613,6 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
     type_to_class: dict[str, type[StreamBaseResponse]] = {
         ResponseType.START.value: StreamStart,
         ResponseType.FINISH.value: StreamFinish,
-        ResponseType.START_STEP.value: StreamStartStep,
-        ResponseType.FINISH_STEP.value: StreamFinishStep,
         ResponseType.TEXT_START.value: StreamTextStart,
         ResponseType.TEXT_DELTA.value: StreamTextDelta,
         ResponseType.TEXT_END.value: StreamTextEnd,
@@ -13,32 +13,10 @@ from backend.api.features.chat.tools.models import (
     NoResultsResponse,
 )
 from backend.api.features.store.hybrid_search import unified_hybrid_search
-from backend.data.block import BlockType, get_block
+from backend.data.block import get_block
 
 logger = logging.getLogger(__name__)
 
-_TARGET_RESULTS = 10
-# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
-# 40 is 2x current removed; speed of query 10 vs 40 is minimial
-_OVERFETCH_PAGE_SIZE = 40
-
-# Block types that only work within graphs and cannot run standalone in CoPilot.
-COPILOT_EXCLUDED_BLOCK_TYPES = {
-    BlockType.INPUT,  # Graph interface definition - data enters via chat, not graph inputs
-    BlockType.OUTPUT,  # Graph interface definition - data exits via chat, not graph outputs
-    BlockType.WEBHOOK,  # Wait for external events - would hang forever in CoPilot
-    BlockType.WEBHOOK_MANUAL,  # Same as WEBHOOK
-    BlockType.NOTE,  # Visual annotation only - no runtime behavior
-    BlockType.HUMAN_IN_THE_LOOP,  # Pauses for human approval - CoPilot IS human-in-the-loop
-    BlockType.AGENT,  # AgentExecutorBlock requires execution_context - use run_agent tool
-}
-
-# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
-COPILOT_EXCLUDED_BLOCK_IDS = {
-    # SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
-    "3b191d9f-356f-482d-8238-ba04b6d18381",
-}
-
 
 class FindBlockTool(BaseTool):
     """Tool for searching available blocks."""
@@ -110,7 +88,7 @@ class FindBlockTool(BaseTool):
             query=query,
             content_types=[ContentType.BLOCK],
             page=1,
-            page_size=_OVERFETCH_PAGE_SIZE,
+            page_size=10,
         )
 
         if not results:
@@ -130,90 +108,60 @@ class FindBlockTool(BaseTool):
|
|||||||
block = get_block(block_id)
|
block = get_block(block_id)
|
||||||
|
|
||||||
# Skip disabled blocks
|
# Skip disabled blocks
|
||||||
if not block or block.disabled:
|
if block and not block.disabled:
|
||||||
continue
|
# Get input/output schemas
|
||||||
|
input_schema = {}
|
||||||
|
output_schema = {}
|
||||||
|
try:
|
||||||
|
input_schema = block.input_schema.jsonschema()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
output_schema = block.output_schema.jsonschema()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Skip blocks excluded from CoPilot (graph-only blocks)
|
# Get categories from block instance
|
||||||
if (
|
categories = []
|
||||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
if hasattr(block, "categories") and block.categories:
|
||||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
categories = [cat.value for cat in block.categories]
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get input/output schemas
|
# Extract required inputs for easier use
|
||||||
input_schema = {}
|
required_inputs: list[BlockInputFieldInfo] = []
|
||||||
output_schema = {}
|
if input_schema:
|
||||||
try:
|
properties = input_schema.get("properties", {})
|
||||||
input_schema = block.input_schema.jsonschema()
|
required_fields = set(input_schema.get("required", []))
|
||||||
except Exception as e:
|
# Get credential field names to exclude from required inputs
|
||||||
logger.debug(
|
credentials_fields = set(
|
||||||
"Failed to generate input schema for block %s: %s",
|
block.input_schema.get_credentials_fields().keys()
|
||||||
block_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
output_schema = block.output_schema.jsonschema()
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(
|
|
||||||
"Failed to generate output schema for block %s: %s",
|
|
||||||
block_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get categories from block instance
|
|
||||||
categories = []
|
|
||||||
if hasattr(block, "categories") and block.categories:
|
|
||||||
categories = [cat.value for cat in block.categories]
|
|
||||||
|
|
||||||
# Extract required inputs for easier use
|
|
||||||
required_inputs: list[BlockInputFieldInfo] = []
|
|
||||||
if input_schema:
|
|
||||||
properties = input_schema.get("properties", {})
|
|
||||||
required_fields = set(input_schema.get("required", []))
|
|
||||||
# Get credential field names to exclude from required inputs
|
|
||||||
credentials_fields = set(
|
|
||||||
block.input_schema.get_credentials_fields().keys()
|
|
||||||
)
|
|
||||||
|
|
||||||
for field_name, field_schema in properties.items():
|
|
||||||
# Skip credential fields - they're handled separately
|
|
||||||
if field_name in credentials_fields:
|
|
||||||
continue
|
|
||||||
|
|
||||||
required_inputs.append(
|
|
||||||
BlockInputFieldInfo(
|
|
||||||
name=field_name,
|
|
||||||
type=field_schema.get("type", "string"),
|
|
||||||
description=field_schema.get("description", ""),
|
|
||||||
required=field_name in required_fields,
|
|
||||||
default=field_schema.get("default"),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
blocks.append(
|
for field_name, field_schema in properties.items():
|
||||||
BlockInfoSummary(
|
# Skip credential fields - they're handled separately
|
||||||
id=block_id,
|
if field_name in credentials_fields:
|
||||||
name=block.name,
|
continue
|
||||||
description=block.description or "",
|
|
||||||
categories=categories,
|
required_inputs.append(
|
||||||
input_schema=input_schema,
|
BlockInputFieldInfo(
|
||||||
output_schema=output_schema,
|
name=field_name,
|
||||||
required_inputs=required_inputs,
|
type=field_schema.get("type", "string"),
|
||||||
|
description=field_schema.get("description", ""),
|
||||||
|
required=field_name in required_fields,
|
||||||
|
default=field_schema.get("default"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
blocks.append(
|
||||||
|
BlockInfoSummary(
|
||||||
|
id=block_id,
|
||||||
|
name=block.name,
|
||||||
|
description=block.description or "",
|
||||||
|
categories=categories,
|
||||||
|
input_schema=input_schema,
|
||||||
|
output_schema=output_schema,
|
||||||
|
required_inputs=required_inputs,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if len(blocks) >= _TARGET_RESULTS:
|
|
||||||
break
|
|
||||||
|
|
||||||
if blocks and len(blocks) < _TARGET_RESULTS:
|
|
||||||
logger.debug(
|
|
||||||
"find_block returned %d/%d results for query '%s' "
|
|
||||||
"(filtered %d excluded/disabled blocks)",
|
|
||||||
len(blocks),
|
|
||||||
_TARGET_RESULTS,
|
|
||||||
query,
|
|
||||||
len(results) - len(blocks),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not blocks:
|
if not blocks:
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
|
|||||||
@@ -1,139 +0,0 @@
|
|||||||
"""Tests for block filtering in FindBlockTool."""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.api.features.chat.tools.find_block import (
|
|
||||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
|
||||||
FindBlockTool,
|
|
||||||
)
|
|
||||||
from backend.api.features.chat.tools.models import BlockListResponse
|
|
||||||
from backend.data.block import BlockType
|
|
||||||
|
|
||||||
from ._test_data import make_session
|
|
||||||
|
|
||||||
_TEST_USER_ID = "test-user-find-block"
|
|
||||||
|
|
||||||
|
|
||||||
def make_mock_block(
|
|
||||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
|
||||||
):
|
|
||||||
"""Create a mock block for testing."""
|
|
||||||
mock = MagicMock()
|
|
||||||
mock.id = block_id
|
|
||||||
mock.name = name
|
|
||||||
mock.description = f"{name} description"
|
|
||||||
mock.block_type = block_type
|
|
||||||
mock.disabled = disabled
|
|
||||||
mock.input_schema = MagicMock()
|
|
||||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
|
||||||
mock.input_schema.get_credentials_fields.return_value = {}
|
|
||||||
mock.output_schema = MagicMock()
|
|
||||||
mock.output_schema.jsonschema.return_value = {}
|
|
||||||
mock.categories = []
|
|
||||||
return mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestFindBlockFiltering:
|
|
||||||
"""Tests for block filtering in FindBlockTool."""
|
|
||||||
|
|
||||||
def test_excluded_block_types_contains_expected_types(self):
|
|
||||||
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
|
||||||
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
|
|
||||||
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
|
||||||
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
|
||||||
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_type_filtered_from_results(self):
|
|
||||||
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
|
||||||
search_results = [
|
|
||||||
{"content_id": "input-block-id", "score": 0.9},
|
|
||||||
{"content_id": "standard-block-id", "score": 0.8},
|
|
||||||
]
|
|
||||||
|
|
||||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
|
||||||
standard_block = make_mock_block(
|
|
||||||
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_get_block(block_id):
|
|
||||||
return {
|
|
||||||
"input-block-id": input_block,
|
|
||||||
"standard-block-id": standard_block,
|
|
||||||
}.get(block_id)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=(search_results, 2),
|
|
||||||
):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.get_block",
|
|
||||||
side_effect=mock_get_block,
|
|
||||||
):
|
|
||||||
tool = FindBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID, session=session, query="test"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should only return the standard block, not the INPUT block
|
|
||||||
assert isinstance(response, BlockListResponse)
|
|
||||||
assert len(response.blocks) == 1
|
|
||||||
assert response.blocks[0].id == "standard-block-id"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_id_filtered_from_results(self):
|
|
||||||
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
|
||||||
search_results = [
|
|
||||||
{"content_id": smart_decision_id, "score": 0.9},
|
|
||||||
{"content_id": "normal-block-id", "score": 0.8},
|
|
||||||
]
|
|
||||||
|
|
||||||
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
|
||||||
smart_block = make_mock_block(
|
|
||||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
normal_block = make_mock_block(
|
|
||||||
"normal-block-id", "Normal Block", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_get_block(block_id):
|
|
||||||
return {
|
|
||||||
smart_decision_id: smart_block,
|
|
||||||
"normal-block-id": normal_block,
|
|
||||||
}.get(block_id)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=(search_results, 2),
|
|
||||||
):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.get_block",
|
|
||||||
side_effect=mock_get_block,
|
|
||||||
):
|
|
||||||
tool = FindBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID, session=session, query="decision"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should only return normal block, not SmartDecisionMakerBlock
|
|
||||||
assert isinstance(response, BlockListResponse)
|
|
||||||
assert len(response.blocks) == 1
|
|
||||||
assert response.blocks[0].id == "normal-block-id"
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
"""Shared helpers for chat tools."""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def get_inputs_from_schema(
|
|
||||||
input_schema: dict[str, Any],
|
|
||||||
exclude_fields: set[str] | None = None,
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Extract input field info from JSON schema."""
|
|
||||||
if not isinstance(input_schema, dict):
|
|
||||||
return []
|
|
||||||
|
|
||||||
exclude = exclude_fields or set()
|
|
||||||
properties = input_schema.get("properties", {})
|
|
||||||
required = set(input_schema.get("required", []))
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"name": name,
|
|
||||||
"title": schema.get("title", name),
|
|
||||||
"type": schema.get("type", "string"),
|
|
||||||
"description": schema.get("description", ""),
|
|
||||||
"required": name in required,
|
|
||||||
"default": schema.get("default"),
|
|
||||||
}
|
|
||||||
for name, schema in properties.items()
|
|
||||||
if name not in exclude
|
|
||||||
]
|
|
||||||
@@ -24,7 +24,6 @@ from backend.util.timezone_utils import (
 )
 
 from .base import BaseTool
-from .helpers import get_inputs_from_schema
 from .models import (
     AgentDetails,
     AgentDetailsResponse,
@@ -262,7 +261,7 @@ class RunAgentTool(BaseTool):
                 ),
                 requirements={
                     "credentials": requirements_creds_list,
-                    "inputs": get_inputs_from_schema(graph.input_schema),
+                    "inputs": self._get_inputs_list(graph.input_schema),
                     "execution_modes": self._get_execution_modes(graph),
                 },
             ),
@@ -370,6 +369,22 @@ class RunAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
|
"""Extract inputs list from schema."""
|
||||||
|
inputs_list = []
|
||||||
|
if isinstance(input_schema, dict) and "properties" in input_schema:
|
||||||
|
for field_name, field_schema in input_schema["properties"].items():
|
||||||
|
inputs_list.append(
|
||||||
|
{
|
||||||
|
"name": field_name,
|
||||||
|
"title": field_schema.get("title", field_name),
|
||||||
|
"type": field_schema.get("type", "string"),
|
||||||
|
"description": field_schema.get("description", ""),
|
||||||
|
"required": field_name in input_schema.get("required", []),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return inputs_list
|
||||||
|
|
||||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||||
"""Get available execution modes for the graph."""
|
"""Get available execution modes for the graph."""
|
||||||
trigger_info = graph.trigger_setup_info
|
trigger_info = graph.trigger_setup_info
|
||||||
@@ -383,7 +398,7 @@ class RunAgentTool(BaseTool):
         suffix: str,
     ) -> str:
         """Build a message describing available inputs for an agent."""
-        inputs_list = get_inputs_from_schema(graph.input_schema)
+        inputs_list = self._get_inputs_list(graph.input_schema)
         required_names = [i["name"] for i in inputs_list if i["required"]]
         optional_names = [i["name"] for i in inputs_list if not i["required"]]
 
@@ -8,19 +8,14 @@ from typing import Any
 from pydantic_core import PydanticUndefined
 
 from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools.find_block import (
-    COPILOT_EXCLUDED_BLOCK_IDS,
-    COPILOT_EXCLUDED_BLOCK_TYPES,
-)
-from backend.data.block import AnyBlockSchema, get_block
+from backend.data.block import get_block
 from backend.data.execution import ExecutionContext
-from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
+from backend.data.model import CredentialsMetaInput
 from backend.data.workspace import get_or_create_workspace
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import BlockError
 
 from .base import BaseTool
-from .helpers import get_inputs_from_schema
 from .models import (
     BlockOutputResponse,
     ErrorResponse,
@@ -29,10 +24,7 @@ from .models import (
     ToolResponseBase,
     UserReadiness,
 )
-from .utils import (
-    build_missing_credentials_from_field_info,
-    match_credentials_to_requirements,
-)
+from .utils import build_missing_credentials_from_field_info
 
 logger = logging.getLogger(__name__)
 
@@ -81,6 +73,91 @@ class RunBlockTool(BaseTool):
     def requires_auth(self) -> bool:
         return True
 
+    async def _check_block_credentials(
+        self,
+        user_id: str,
+        block: Any,
+        input_data: dict[str, Any] | None = None,
+    ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
+        """
+        Check if user has required credentials for a block.
+
+        Args:
+            user_id: User ID
+            block: Block to check credentials for
+            input_data: Input data for the block (used to determine provider via discriminator)
+
+        Returns:
+            tuple[matched_credentials, missing_credentials]
+        """
+        matched_credentials: dict[str, CredentialsMetaInput] = {}
+        missing_credentials: list[CredentialsMetaInput] = []
+        input_data = input_data or {}
+
+        # Get credential field info from block's input schema
+        credentials_fields_info = block.input_schema.get_credentials_fields_info()
+
+        if not credentials_fields_info:
+            return matched_credentials, missing_credentials
+
+        # Get user's available credentials
+        creds_manager = IntegrationCredentialsManager()
+        available_creds = await creds_manager.store.get_all_creds(user_id)
+
+        for field_name, field_info in credentials_fields_info.items():
+            effective_field_info = field_info
+            if field_info.discriminator and field_info.discriminator_mapping:
+                # Get discriminator from input, falling back to schema default
+                discriminator_value = input_data.get(field_info.discriminator)
+                if discriminator_value is None:
+                    field = block.input_schema.model_fields.get(
+                        field_info.discriminator
+                    )
+                    if field and field.default is not PydanticUndefined:
+                        discriminator_value = field.default
+
+                if (
+                    discriminator_value
+                    and discriminator_value in field_info.discriminator_mapping
+                ):
+                    effective_field_info = field_info.discriminate(discriminator_value)
+                    logger.debug(
+                        f"Discriminated provider for {field_name}: "
+                        f"{discriminator_value} -> {effective_field_info.provider}"
+                    )
+
+            matching_cred = next(
+                (
+                    cred
+                    for cred in available_creds
+                    if cred.provider in effective_field_info.provider
+                    and cred.type in effective_field_info.supported_types
+                ),
+                None,
+            )
+
+            if matching_cred:
+                matched_credentials[field_name] = CredentialsMetaInput(
+                    id=matching_cred.id,
+                    provider=matching_cred.provider,  # type: ignore
+                    type=matching_cred.type,
+                    title=matching_cred.title,
+                )
+            else:
+                # Create a placeholder for the missing credential
+                provider = next(iter(effective_field_info.provider), "unknown")
+                cred_type = next(iter(effective_field_info.supported_types), "api_key")
+                missing_credentials.append(
+                    CredentialsMetaInput(
+                        id=field_name,
+                        provider=provider,  # type: ignore
+                        type=cred_type,  # type: ignore
+                        title=field_name.replace("_", " ").title(),
+                    )
+                )
+
+        return matched_credentials, missing_credentials
+
     async def _execute(
         self,
         user_id: str | None,
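To make the discriminator step above concrete, here is a minimal self-contained sketch of the idea; the classes and mapping are illustrative stand-ins, not the real CredentialsFieldInfo API. A credentials field that could be satisfied by several providers is narrowed to a single provider based on a value in the block's input, such as the chosen model:

```python
from dataclasses import dataclass, field


@dataclass
class FieldInfo:
    """Toy stand-in for a credentials field description."""
    provider: set[str]                       # providers that could satisfy the field
    supported_types: set[str]
    discriminator: str | None = None         # input field that picks the provider
    discriminator_mapping: dict[str, str] = field(default_factory=dict)

    def discriminate(self, value: str) -> "FieldInfo":
        # Narrow the provider set to the one mapped from the discriminator value.
        return FieldInfo({self.discriminator_mapping[value]}, self.supported_types)


llm_creds = FieldInfo(
    provider={"openai", "anthropic"},
    supported_types={"api_key"},
    discriminator="model",
    discriminator_mapping={"gpt-4o": "openai", "claude-3-5-sonnet": "anthropic"},
)

# input_data = {"model": "claude-3-5-sonnet"} narrows the requirement to Anthropic:
effective = llm_creds.discriminate("claude-3-5-sonnet")
print(effective.provider)  # {'anthropic'}
```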
@@ -135,24 +212,11 @@ class RunBlockTool(BaseTool):
                 session_id=session_id,
             )
 
-        # Check if block is excluded from CoPilot (graph-only blocks)
-        if (
-            block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
-            or block.id in COPILOT_EXCLUDED_BLOCK_IDS
-        ):
-            return ErrorResponse(
-                message=(
-                    f"Block '{block.name}' cannot be run directly in CoPilot. "
-                    "This block is designed for use within graphs only."
-                ),
-                session_id=session_id,
-            )
-
         logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
 
         creds_manager = IntegrationCredentialsManager()
-        matched_credentials, missing_credentials = (
-            await self._resolve_block_credentials(user_id, block, input_data)
+        matched_credentials, missing_credentials = await self._check_block_credentials(
+            user_id, block, input_data
         )
 
         if missing_credentials:
@@ -281,75 +345,29 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _resolve_block_credentials(
|
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
block: AnyBlockSchema,
|
|
||||||
input_data: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Resolve credentials for a block by matching user's available credentials.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
block: Block to resolve credentials for
|
|
||||||
input_data: Input data for the block (used to determine provider via discriminator)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple of (matched_credentials, missing_credentials) - matched credentials
|
|
||||||
are used for block execution, missing ones indicate setup requirements.
|
|
||||||
"""
|
|
||||||
input_data = input_data or {}
|
|
||||||
requirements = self._resolve_discriminated_credentials(block, input_data)
|
|
||||||
|
|
||||||
if not requirements:
|
|
||||||
return {}, []
|
|
||||||
|
|
||||||
return await match_credentials_to_requirements(user_id, requirements)
|
|
||||||
|
|
||||||
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
|
||||||
"""Extract non-credential inputs from block schema."""
|
"""Extract non-credential inputs from block schema."""
|
||||||
|
inputs_list = []
|
||||||
schema = block.input_schema.jsonschema()
|
schema = block.input_schema.jsonschema()
|
||||||
|
properties = schema.get("properties", {})
|
||||||
|
required_fields = set(schema.get("required", []))
|
||||||
|
|
||||||
|
# Get credential field names to exclude
|
||||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||||
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
|
||||||
|
|
||||||
def _resolve_discriminated_credentials(
|
for field_name, field_schema in properties.items():
|
||||||
self,
|
# Skip credential fields
|
||||||
block: AnyBlockSchema,
|
if field_name in credentials_fields:
|
||||||
input_data: dict[str, Any],
|
continue
|
||||||
) -> dict[str, CredentialsFieldInfo]:
|
|
||||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
|
||||||
if not credentials_fields_info:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
inputs_list.append(
|
||||||
|
{
|
||||||
|
"name": field_name,
|
||||||
|
"title": field_schema.get("title", field_name),
|
||||||
|
"type": field_schema.get("type", "string"),
|
||||||
|
"description": field_schema.get("description", ""),
|
||||||
|
"required": field_name in required_fields,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
return inputs_list
|
||||||
effective_field_info = field_info
|
|
||||||
|
|
||||||
if field_info.discriminator and field_info.discriminator_mapping:
|
|
||||||
discriminator_value = input_data.get(field_info.discriminator)
|
|
||||||
if discriminator_value is None:
|
|
||||||
field = block.input_schema.model_fields.get(
|
|
||||||
field_info.discriminator
|
|
||||||
)
|
|
||||||
if field and field.default is not PydanticUndefined:
|
|
||||||
discriminator_value = field.default
|
|
||||||
|
|
||||||
if (
|
|
||||||
discriminator_value
|
|
||||||
and discriminator_value in field_info.discriminator_mapping
|
|
||||||
):
|
|
||||||
effective_field_info = field_info.discriminate(discriminator_value)
|
|
||||||
# For host-scoped credentials, add the discriminator value
|
|
||||||
# (e.g., URL) so _credential_is_for_host can match it
|
|
||||||
effective_field_info.discriminator_values.add(discriminator_value)
|
|
||||||
logger.debug(
|
|
||||||
f"Discriminated provider for {field_name}: "
|
|
||||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved[field_name] = effective_field_info
|
|
||||||
|
|
||||||
return resolved
|
|
||||||
|
|||||||
@@ -1,106 +0,0 @@
|
|||||||
"""Tests for block execution guards in RunBlockTool."""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.api.features.chat.tools.models import ErrorResponse
|
|
||||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
|
||||||
from backend.data.block import BlockType
|
|
||||||
|
|
||||||
from ._test_data import make_session
|
|
||||||
|
|
||||||
_TEST_USER_ID = "test-user-run-block"
|
|
||||||
|
|
||||||
|
|
||||||
def make_mock_block(
|
|
||||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
|
||||||
):
|
|
||||||
"""Create a mock block for testing."""
|
|
||||||
mock = MagicMock()
|
|
||||||
mock.id = block_id
|
|
||||||
mock.name = name
|
|
||||||
mock.block_type = block_type
|
|
||||||
mock.disabled = disabled
|
|
||||||
mock.input_schema = MagicMock()
|
|
||||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
|
||||||
mock.input_schema.get_credentials_fields_info.return_value = []
|
|
||||||
return mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunBlockFiltering:
|
|
||||||
"""Tests for block execution guards in RunBlockTool."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_type_returns_error(self):
|
|
||||||
"""Attempting to execute a block with excluded BlockType returns error."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=input_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id="input-block-id",
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, ErrorResponse)
|
|
||||||
assert "cannot be run directly in CoPilot" in response.message
|
|
||||||
assert "designed for use within graphs only" in response.message
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_id_returns_error(self):
|
|
||||||
"""Attempting to execute SmartDecisionMakerBlock returns error."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
|
||||||
smart_block = make_mock_block(
|
|
||||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=smart_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id=smart_decision_id,
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, ErrorResponse)
|
|
||||||
assert "cannot be run directly in CoPilot" in response.message
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_non_excluded_block_passes_guard(self):
|
|
||||||
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
standard_block = make_mock_block(
|
|
||||||
"standard-id", "HTTP Request", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=standard_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id="standard-id",
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should NOT be an ErrorResponse about CoPilot exclusion
|
|
||||||
# (may be other errors like missing credentials, but not the exclusion guard)
|
|
||||||
if isinstance(response, ErrorResponse):
|
|
||||||
assert "cannot be run directly in CoPilot" not in response.message
|
|
||||||
@@ -8,7 +8,6 @@ from backend.api.features.library import model as library_model
 from backend.api.features.store import db as store_db
 from backend.data.graph import GraphModel
 from backend.data.model import (
-    Credentials,
     CredentialsFieldInfo,
     CredentialsMetaInput,
     HostScopedCredentials,
@@ -118,7 +117,7 @@ def build_missing_credentials_from_graph(
     preserving all supported credential types for each field.
     """
     matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
-    aggregated_fields = graph.aggregate_credentials_inputs()
+    aggregated_fields = graph.regular_credentials_inputs
 
     return {
         field_key: _serialize_missing_credential(field_key, field_info)
@@ -224,99 +223,6 @@ async def get_or_create_library_agent(
|
|||||||
return library_agents[0]
|
return library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
async def match_credentials_to_requirements(
|
|
||||||
user_id: str,
|
|
||||||
requirements: dict[str, CredentialsFieldInfo],
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Match user's credentials against a dictionary of credential requirements.
|
|
||||||
|
|
||||||
This is the core matching logic shared by both graph and block credential matching.
|
|
||||||
"""
|
|
||||||
matched: dict[str, CredentialsMetaInput] = {}
|
|
||||||
missing: list[CredentialsMetaInput] = []
|
|
||||||
|
|
||||||
if not requirements:
|
|
||||||
return matched, missing
|
|
||||||
|
|
||||||
available_creds = await get_user_credentials(user_id)
|
|
||||||
|
|
||||||
for field_name, field_info in requirements.items():
|
|
||||||
matching_cred = find_matching_credential(available_creds, field_info)
|
|
||||||
|
|
||||||
if matching_cred:
|
|
||||||
try:
|
|
||||||
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
|
||||||
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
|
||||||
f"credential_id={matching_cred.id}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
provider = next(iter(field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
|
||||||
missing.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=f"{field_name} (validation failed: {e})",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
provider = next(iter(field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
|
||||||
missing.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=field_name.replace("_", " ").title(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matched, missing
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_credentials(user_id: str) -> list[Credentials]:
|
|
||||||
"""Get all available credentials for a user."""
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
return await creds_manager.store.get_all_creds(user_id)
|
|
||||||
|
|
||||||
|
|
||||||
def find_matching_credential(
|
|
||||||
available_creds: list[Credentials],
|
|
||||||
field_info: CredentialsFieldInfo,
|
|
||||||
) -> Credentials | None:
|
|
||||||
"""Find a credential that matches the required provider, type, scopes, and host."""
|
|
||||||
for cred in available_creds:
|
|
||||||
if cred.provider not in field_info.provider:
|
|
||||||
continue
|
|
||||||
if cred.type not in field_info.supported_types:
|
|
||||||
continue
|
|
||||||
if cred.type == "oauth2" and not _credential_has_required_scopes(
|
|
||||||
cred, field_info
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
|
|
||||||
continue
|
|
||||||
return cred
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_credential_meta_from_match(
|
|
||||||
matching_cred: Credentials,
|
|
||||||
) -> CredentialsMetaInput:
|
|
||||||
"""Create a CredentialsMetaInput from a matched credential."""
|
|
||||||
return CredentialsMetaInput(
|
|
||||||
id=matching_cred.id,
|
|
||||||
provider=matching_cred.provider, # type: ignore
|
|
||||||
type=matching_cred.type,
|
|
||||||
title=matching_cred.title,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def match_user_credentials_to_graph(
|
async def match_user_credentials_to_graph(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
graph: GraphModel,
|
graph: GraphModel,
|
||||||
@@ -338,7 +244,7 @@ async def match_user_credentials_to_graph(
     missing_creds: list[str] = []
 
     # Get aggregated credentials requirements from the graph
-    aggregated_creds = graph.aggregate_credentials_inputs()
+    aggregated_creds = graph.regular_credentials_inputs
     logger.debug(
         f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
     )
@@ -425,6 +331,8 @@ def _credential_has_required_scopes(
     # If no scopes are required, any credential matches
     if not requirements.required_scopes:
         return True
+
+    # Check that credential scopes are a superset of required scopes
     return set(credential.scopes).issuperset(requirements.required_scopes)
 
 
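As a quick illustration of the superset check used above (the scope values are made up):

```python
required_scopes = {"repo", "read:org"}

# A credential holding a broader grant satisfies the requirement...
assert {"repo", "read:org", "gist"}.issuperset(required_scopes)

# ...while one missing a required scope does not.
assert not {"repo"}.issuperset(required_scopes)
```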
@@ -0,0 +1,78 @@
|
|||||||
|
"""Tests for chat tools utility functions."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.data.model import CredentialsFieldInfo
|
||||||
|
|
||||||
|
|
||||||
|
def _make_regular_field() -> CredentialsFieldInfo:
|
||||||
|
return CredentialsFieldInfo.model_validate(
|
||||||
|
{
|
||||||
|
"credentials_provider": ["github"],
|
||||||
|
"credentials_types": ["api_key"],
|
||||||
|
"is_auto_credential": False,
|
||||||
|
},
|
||||||
|
by_alias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_missing_credentials_excludes_auto_creds():
|
||||||
|
"""
|
||||||
|
build_missing_credentials_from_graph() should use regular_credentials_inputs
|
||||||
|
and thus exclude auto_credentials from the "missing" set.
|
||||||
|
"""
|
||||||
|
from backend.api.features.chat.tools.utils import (
|
||||||
|
build_missing_credentials_from_graph,
|
||||||
|
)
|
||||||
|
|
||||||
|
regular_field = _make_regular_field()
|
||||||
|
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
# regular_credentials_inputs should only return the non-auto field
|
||||||
|
mock_graph.regular_credentials_inputs = {
|
||||||
|
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = build_missing_credentials_from_graph(mock_graph, matched_credentials=None)
|
||||||
|
|
||||||
|
# Should include the regular credential
|
||||||
|
assert "github_api_key" in result
|
||||||
|
# Should NOT include the auto_credential (not in regular_credentials_inputs)
|
||||||
|
assert "google_oauth2" not in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_match_user_credentials_excludes_auto_creds():
|
||||||
|
"""
|
||||||
|
match_user_credentials_to_graph() should use regular_credentials_inputs
|
||||||
|
and thus exclude auto_credentials from matching.
|
||||||
|
"""
|
||||||
|
from backend.api.features.chat.tools.utils import match_user_credentials_to_graph
|
||||||
|
|
||||||
|
regular_field = _make_regular_field()
|
||||||
|
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
mock_graph.id = "test-graph"
|
||||||
|
# regular_credentials_inputs returns only non-auto fields
|
||||||
|
mock_graph.regular_credentials_inputs = {
|
||||||
|
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mock the credentials manager to return no credentials
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.utils.IntegrationCredentialsManager"
|
||||||
|
) as MockCredsMgr:
|
||||||
|
mock_store = AsyncMock()
|
||||||
|
mock_store.get_all_creds.return_value = []
|
||||||
|
MockCredsMgr.return_value.store = mock_store
|
||||||
|
|
||||||
|
matched, missing = await match_user_credentials_to_graph(
|
||||||
|
user_id="test-user", graph=mock_graph
|
||||||
|
)
|
||||||
|
|
||||||
|
# No credentials available, so github should be missing
|
||||||
|
assert len(matched) == 0
|
||||||
|
assert len(missing) == 1
|
||||||
|
assert "github_api_key" in missing[0]
|
||||||
@@ -1103,7 +1103,7 @@ async def create_preset_from_graph_execution(
         raise NotFoundError(
             f"Graph #{graph_execution.graph_id} not found or accessible"
         )
-    elif len(graph.aggregate_credentials_inputs()) > 0:
+    elif len(graph.regular_credentials_inputs) > 0:
         raise ValueError(
             f"Graph execution #{graph_exec_id} can't be turned into a preset "
             "because it was run before this feature existed "
@@ -8,7 +8,6 @@ Includes BM25 reranking for improved lexical relevance.
 
 import logging
 import re
-import time
 from dataclasses import dataclass
 from typing import Any, Literal
 
@@ -363,11 +362,7 @@ async def unified_hybrid_search(
         LIMIT {limit_param} OFFSET {offset_param}
     """
 
-    try:
-        results = await query_raw_with_schema(sql_query, *params)
-    except Exception as e:
-        await _log_vector_error_diagnostics(e)
-        raise
+    results = await query_raw_with_schema(sql_query, *params)
 
     total = results[0]["total_count"] if results else 0
     # Apply BM25 reranking
@@ -691,11 +686,7 @@ async def hybrid_search(
         LIMIT {limit_param} OFFSET {offset_param}
     """
 
-    try:
-        results = await query_raw_with_schema(sql_query, *params)
-    except Exception as e:
-        await _log_vector_error_diagnostics(e)
-        raise
+    results = await query_raw_with_schema(sql_query, *params)
 
     total = results[0]["total_count"] if results else 0
 
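The next hunk drops the `_log_vector_error_diagnostics` helper that these `try/except` blocks called. Its core trick was rate-limiting noisy diagnostics to at most once per interval via a module-level timestamp; a minimal sketch of that pattern (simplified, not the removed function itself):

```python
import logging
import time

logger = logging.getLogger(__name__)

_DIAG_INTERVAL_SECONDS = 60
_last_diag_time: float = 0.0


def log_diagnostics_rate_limited(error: Exception) -> None:
    """Log details for an error at most once per interval to avoid log spam."""
    global _last_diag_time
    now = time.time()
    if now - _last_diag_time < _DIAG_INTERVAL_SECONDS:
        return
    _last_diag_time = now
    logger.error("diagnostics for error: %s", error)
```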
@@ -727,87 +718,6 @@ async def hybrid_search_simple(
|
|||||||
return await hybrid_search(query=query, page=page, page_size=page_size)
|
return await hybrid_search(query=query, page=page, page_size=page_size)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Diagnostics
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
# Rate limit: only log vector error diagnostics once per this interval
|
|
||||||
_VECTOR_DIAG_INTERVAL_SECONDS = 60
|
|
||||||
_last_vector_diag_time: float = 0
|
|
||||||
|
|
||||||
|
|
||||||
async def _log_vector_error_diagnostics(error: Exception) -> None:
|
|
||||||
"""Log diagnostic info when 'type vector does not exist' error occurs.
|
|
||||||
|
|
||||||
Note: Diagnostic queries use query_raw_with_schema which may run on a different
|
|
||||||
pooled connection than the one that failed. Session-level search_path can differ,
|
|
||||||
so these diagnostics show cluster-wide state, not necessarily the failed session.
|
|
||||||
|
|
||||||
Includes rate limiting to avoid log spam - only logs once per minute.
|
|
||||||
Caller should re-raise the error after calling this function.
|
|
||||||
"""
|
|
||||||
global _last_vector_diag_time
|
|
||||||
|
|
||||||
# Check if this is the vector type error
|
|
||||||
error_str = str(error).lower()
|
|
||||||
if not (
|
|
||||||
"type" in error_str and "vector" in error_str and "does not exist" in error_str
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Rate limit: only log once per interval
|
|
||||||
now = time.time()
|
|
||||||
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
|
|
||||||
return
|
|
||||||
_last_vector_diag_time = now
|
|
||||||
|
|
||||||
try:
|
|
||||||
diagnostics: dict[str, object] = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
search_path_result = await query_raw_with_schema("SHOW search_path")
|
|
||||||
diagnostics["search_path"] = search_path_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["search_path"] = f"Error: {e}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
schema_result = await query_raw_with_schema("SELECT current_schema()")
|
|
||||||
diagnostics["current_schema"] = schema_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["current_schema"] = f"Error: {e}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
user_result = await query_raw_with_schema(
|
|
||||||
"SELECT current_user, session_user, current_database()"
|
|
||||||
)
|
|
||||||
diagnostics["user_info"] = user_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["user_info"] = f"Error: {e}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check pgvector extension installation (cluster-wide, stable info)
|
|
||||||
ext_result = await query_raw_with_schema(
|
|
||||||
"SELECT extname, extversion, nspname as schema "
|
|
||||||
"FROM pg_extension e "
|
|
||||||
"JOIN pg_namespace n ON e.extnamespace = n.oid "
|
|
||||||
"WHERE extname = 'vector'"
|
|
||||||
)
|
|
||||||
diagnostics["pgvector_extension"] = ext_result
|
|
||||||
except Exception as e:
|
|
||||||
diagnostics["pgvector_extension"] = f"Error: {e}"
|
|
||||||
|
|
||||||
logger.error(
|
|
||||||
f"Vector type error diagnostics:\n"
|
|
||||||
f" Error: {error}\n"
|
|
||||||
f" search_path: {diagnostics.get('search_path')}\n"
|
|
||||||
f" current_schema: {diagnostics.get('current_schema')}\n"
|
|
||||||
f" user_info: {diagnostics.get('user_info')}\n"
|
|
||||||
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
|
|
||||||
)
|
|
||||||
except Exception as diag_error:
|
|
||||||
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
|
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
||||||
# for existing code that expects the popularity parameter
|
# for existing code that expects the popularity parameter
|
||||||
HybridSearchWeights = StoreAgentSearchWeights
|
HybridSearchWeights = StoreAgentSearchWeights
|
||||||
|
|||||||
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
 
         try:
-            webset = await aexa.websets.get(id=input_data.external_id)
+            webset = aexa.websets.get(id=input_data.external_id)
             webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
 
             yield "webset", webset_result
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
                 count=input_data.search_count,
             )
 
-            webset = await aexa.websets.create(
+            webset = aexa.websets.create(
                 params=CreateWebsetParameters(
                     search=search_params,
                     external_id=input_data.external_id,
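The remaining hunks in these Exa blocks all toggle the same thing: whether the `aexa.websets.*` SDK calls are awaited. As a general Python illustration (toy client, not the Exa SDK), calling an async method without `await` yields a coroutine object rather than the result, which is why the two sides of the diff are not interchangeable:

```python
import asyncio


class ToyWebsets:
    async def get(self, id: str) -> dict:
        # Stand-in for a network call.
        return {"id": id, "status": "idle"}


async def main() -> None:
    websets = ToyWebsets()

    awaited = await websets.get(id="ws_123")
    print(awaited)              # {'id': 'ws_123', 'status': 'idle'}

    not_awaited = websets.get(id="ws_123")
    print(type(not_awaited))    # <class 'coroutine'> -- not the result
    not_awaited.close()         # avoid a "never awaited" warning


asyncio.run(main())
```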
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
|
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.list(
|
response = aexa.websets.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
|
sdk_webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
|
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
deleted_webset.status.value
|
deleted_webset.status.value
|
||||||
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
|
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
canceled_webset.status.value
|
canceled_webset.status.value
|
||||||
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
|
|||||||
entity["description"] = input_data.entity_description
|
entity["description"] = input_data.entity_description
|
||||||
payload["entity"] = entity
|
payload["entity"] = entity
|
||||||
|
|
||||||
sdk_preview = await aexa.websets.preview(params=payload)
|
sdk_preview = aexa.websets.preview(params=payload)
|
||||||
|
|
||||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||||
|
|
||||||
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Extract basic info
|
# Extract basic info
|
||||||
webset_id = webset.id
|
webset_id = webset.id
|
||||||
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
total_items = 0
|
total_items = 0
|
||||||
|
|
||||||
if input_data.include_sample_items and input_data.sample_size > 0:
|
if input_data.include_sample_items and input_data.sample_size > 0:
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
sample_items_data = [
|
sample_items_data = [
|
||||||
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset details
|
# Get webset details
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = await aexa.websets.enrichments.create(
|
sdk_enrichment = aexa.websets.enrichments.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_enrich = await aexa.websets.enrichments.get(
|
current_enrich = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=enrichment_id
|
webset_id=input_data.webset_id, id=enrichment_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
|
|
||||||
if current_status in ["completed", "failed", "cancelled"]:
|
if current_status in ["completed", "failed", "cancelled"]:
|
||||||
# Estimate items from webset searches
|
# Estimate items from webset searches
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
for search in webset.searches:
|
for search in webset.searches:
|
||||||
if search.progress:
|
if search.progress:
|
||||||
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = await aexa.websets.enrichments.get(
|
sdk_enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_enrichment = await aexa.websets.enrichments.delete(
|
deleted_enrichment = aexa.websets.enrichments.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_enrichment = await aexa.websets.enrichments.cancel(
|
canceled_enrichment = aexa.websets.enrichments.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to estimate how many items were enriched before cancellation
|
# Try to estimate how many items were enriched before cancellation
|
||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=100
|
webset_id=input_data.webset_id, limit=100
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
# Create mock SDK import object
|
# Create mock SDK import object
|
||||||
mock_import = MagicMock()
|
mock_import = MagicMock()
|
||||||
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
imports=MagicMock(create=AsyncMock(return_value=mock_import))
|
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_import = await aexa.websets.imports.create(
|
sdk_import = aexa.websets.imports.create(
|
||||||
params=payload, csv_data=input_data.csv_data
|
params=payload, csv_data=input_data.csv_data
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
|
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
|
||||||
|
|
||||||
import_obj = ImportModel.from_sdk(sdk_import)
|
import_obj = ImportModel.from_sdk(sdk_import)
|
||||||
|
|
||||||
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.imports.list(
|
response = aexa.websets.imports.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -474,9 +474,7 @@ class ExaDeleteImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_import = await aexa.websets.imports.delete(
|
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
|
||||||
import_id=input_data.import_id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "import_id", deleted_import.id
|
yield "import_id", deleted_import.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -575,14 +573,14 @@ class ExaExportWebsetBlock(Block):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create async iterator for list_all
|
# Create mock iterator
|
||||||
async def async_item_iterator(*args, **kwargs):
|
mock_items = [mock_item1, mock_item2]
|
||||||
for item in [mock_item1, mock_item2]:
|
|
||||||
yield item
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
|
websets=MagicMock(
|
||||||
|
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -604,7 +602,7 @@ class ExaExportWebsetBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
async for sdk_item in item_iterator:
|
for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_item = await aexa.websets.items.get(
|
sdk_item = aexa.websets.items.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
response = None
|
response = None
|
||||||
|
|
||||||
while time.time() - start_time < input_data.wait_timeout:
|
while time.time() - start_time < input_data.wait_timeout:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
interval = min(interval * 1.2, 10)
|
interval = min(interval * 1.2, 10)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_item = await aexa.websets.items.delete(
|
deleted_item = aexa.websets.items.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
async for sdk_item in item_iterator:
|
for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
entity_type = "unknown"
|
entity_type = "unknown"
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Get sample items if requested
|
# Get sample items if requested
|
||||||
sample_items: List[WebsetItemModel] = []
|
sample_items: List[WebsetItemModel] = []
|
||||||
if input_data.sample_size > 0:
|
if input_data.sample_size > 0:
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
# Convert to our stable models
|
# Convert to our stable models
|
||||||
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get items starting from cursor
|
# Get items starting from cursor
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.since_cursor,
|
cursor=input_data.since_cursor,
|
||||||
limit=input_data.max_items,
|
limit=input_data.max_items,
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
# Create mock SDK monitor object
|
# Create mock SDK monitor object
|
||||||
mock_monitor = MagicMock()
|
mock_monitor = MagicMock()
|
||||||
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
|
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.create(params=payload)
|
sdk_monitor = aexa.websets.monitors.create(params=payload)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.update(
|
sdk_monitor = aexa.websets.monitors.update(
|
||||||
monitor_id=input_data.monitor_id, params=payload
|
monitor_id=input_data.monitor_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -522,9 +522,7 @@ class ExaDeleteMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_monitor = await aexa.websets.monitors.delete(
|
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
|
||||||
monitor_id=input_data.monitor_id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "monitor_id", deleted_monitor.id
|
yield "monitor_id", deleted_monitor.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -581,7 +579,7 @@ class ExaListMonitorsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.monitors.list(
|
response = aexa.websets.monitors.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
WebsetTargetStatus.IDLE,
|
WebsetTargetStatus.IDLE,
|
||||||
WebsetTargetStatus.ANY_COMPLETE,
|
WebsetTargetStatus.ANY_COMPLETE,
|
||||||
]:
|
]:
|
||||||
final_webset = await aexa.websets.wait_until_idle(
|
final_webset = aexa.websets.wait_until_idle(
|
||||||
id=input_data.webset_id,
|
id=input_data.webset_id,
|
||||||
timeout=input_data.timeout,
|
timeout=input_data.timeout,
|
||||||
poll_interval=input_data.check_interval,
|
poll_interval=input_data.check_interval,
|
||||||
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
interval = input_data.check_interval
|
interval = input_data.check_interval
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current webset status
|
# Get current webset status
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
current_status = (
|
current_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
|
|
||||||
# Timeout reached
|
# Timeout reached
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
final_status = (
|
final_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current search status using SDK
|
# Get current search status using SDK
|
||||||
search = await aexa.websets.searches.get(
|
search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
search = await aexa.websets.searches.get(
|
search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current enrichment status using SDK
|
# Get current enrichment status using SDK
|
||||||
enrichment = await aexa.websets.enrichments.get(
|
enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
enrichment = await aexa.websets.enrichments.get(
|
enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
) -> tuple[list[SampleEnrichmentModel], int]:
|
) -> tuple[list[SampleEnrichmentModel], int]:
|
||||||
"""Get sample enriched data and count."""
|
"""Get sample enriched data and count."""
|
||||||
# Get a few items to see enrichment results using SDK
|
# Get a few items to see enrichment results using SDK
|
||||||
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
|
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||||
|
|
||||||
sample_data: list[SampleEnrichmentModel] = []
|
sample_data: list[SampleEnrichmentModel] = []
|
||||||
enriched_count = 0
|
enriched_count = 0
|
||||||
|
|||||||
@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
|
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.create(
|
sdk_search = aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
poll_start = time.time()
|
poll_start = time.time()
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_search = await aexa.websets.searches.get(
|
current_search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=search_id
|
webset_id=input_data.webset_id, id=search_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.get(
|
sdk_search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_search = await aexa.websets.searches.cancel(
|
canceled_search = aexa.websets.searches.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset to check existing searches
|
# Get webset to check existing searches
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Look for existing search with same query
|
# Look for existing search with same query
|
||||||
existing_search = None
|
existing_search = None
|
||||||
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
if input_data.entity_type != SearchEntityType.AUTO:
|
if input_data.entity_type != SearchEntityType.AUTO:
|
||||||
payload["entity"] = {"type": input_data.entity_type.value}
|
payload["entity"] = {"type": input_data.entity_type.value}
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.create(
|
sdk_search = aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
@@ -531,12 +531,12 @@ class LLMResponse(BaseModel):
 
 def convert_openai_tool_fmt_to_anthropic(
     openai_tools: list[dict] | None = None,
-) -> Iterable[ToolParam] | anthropic.Omit:
+) -> Iterable[ToolParam] | anthropic.NotGiven:
     """
     Convert OpenAI tool format to Anthropic tool format.
     """
     if not openai_tools or len(openai_tools) == 0:
-        return anthropic.omit
+        return anthropic.NOT_GIVEN
 
     anthropic_tools = []
     for tool in openai_tools:
@@ -596,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
 
 def get_parallel_tool_calls_param(
     llm_model: LlmModel, parallel_tool_calls: bool | None
-) -> bool | openai.Omit:
+):
     """Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
     if llm_model.startswith("o") or parallel_tool_calls is None:
-        return openai.omit
+        return openai.NOT_GIVEN
     return parallel_tool_calls
 
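A minimal usage sketch of the NOT_GIVEN sentinel that get_parallel_tool_calls_param falls back to above (illustrative only, not part of this diff; it assumes the openai Python SDK and uses a placeholder model name):

import openai

client = openai.OpenAI()

def chat(parallel_tool_calls: bool | None):
    # NOT_GIVEN omits the field from the request entirely, which is not the
    # same as sending an explicit False or null.
    value = openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
    return client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name, not taken from the codebase
        messages=[{"role": "user", "content": "hello"}],
        parallel_tool_calls=value,
    )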
@@ -319,6 +319,8 @@ class BlockSchema(BaseModel):
                 "credentials_provider": [config.get("provider", "google")],
                 "credentials_types": [config.get("type", "oauth2")],
                 "credentials_scopes": config.get("scopes"),
+                "is_auto_credential": True,
+                "input_field_name": info["field_name"],
             }
             result[kwarg_name] = CredentialsFieldInfo.model_validate(
                 auto_schema, by_alias=True
@@ -1,8 +1,9 @@
 import logging
-import queue
 from collections import defaultdict
 from datetime import datetime, timedelta, timezone
 from enum import Enum
+from multiprocessing import Manager
+from queue import Empty
 from typing import (
     TYPE_CHECKING,
     Annotated,
@@ -1199,16 +1200,12 @@ class NodeExecutionEntry(BaseModel):
 
 class ExecutionQueue(Generic[T]):
     """
-    Thread-safe queue for managing node execution within a single graph execution.
-
-    Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
-    threads within the same process. If migrating back to ProcessPoolExecutor,
-    replace with multiprocessing.Manager().Queue() for cross-process safety.
+    Queue for managing the execution of agents.
+    This will be shared between different processes
     """
 
     def __init__(self):
-        # Thread-safe queue (not multiprocessing) — see class docstring
-        self.queue: queue.Queue[T] = queue.Queue()
+        self.queue = Manager().Queue()
 
     def add(self, execution: T) -> T:
         self.queue.put(execution)
@@ -1223,7 +1220,7 @@ class ExecutionQueue(Generic[T]):
     def get_or_none(self) -> T | None:
         try:
             return self.queue.get_nowait()
-        except queue.Empty:
+        except Empty:
             return None
 
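An illustrative sketch (not part of the diff) of the trade-off behind the ExecutionQueue change above: queue.Queue is only safe across threads within one process, while multiprocessing.Manager().Queue() hands back a proxy that can be shared across processes:

import queue
from multiprocessing import Manager

def demo() -> None:
    thread_q: queue.Queue[str] = queue.Queue()  # thread-safe within a single process
    thread_q.put("item")
    assert thread_q.get_nowait() == "item"

    manager = Manager()  # starts a manager process; its Queue proxy works across processes
    process_q = manager.Queue()
    process_q.put("item")
    assert process_q.get() == "item"
    manager.shutdown()

if __name__ == "__main__":
    demo()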
@@ -1,58 +0,0 @@
|
|||||||
"""Tests for ExecutionQueue thread-safety."""
|
|
||||||
|
|
||||||
import queue
|
|
||||||
import threading
|
|
||||||
|
|
||||||
from backend.data.execution import ExecutionQueue
|
|
||||||
|
|
||||||
|
|
||||||
def test_execution_queue_uses_stdlib_queue():
|
|
||||||
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
|
|
||||||
q = ExecutionQueue()
|
|
||||||
assert isinstance(q.queue, queue.Queue)
|
|
||||||
|
|
||||||
|
|
||||||
def test_basic_operations():
|
|
||||||
"""Test add, get, empty, and get_or_none."""
|
|
||||||
q = ExecutionQueue()
|
|
||||||
|
|
||||||
assert q.empty() is True
|
|
||||||
assert q.get_or_none() is None
|
|
||||||
|
|
||||||
result = q.add("item1")
|
|
||||||
assert result == "item1"
|
|
||||||
assert q.empty() is False
|
|
||||||
|
|
||||||
item = q.get()
|
|
||||||
assert item == "item1"
|
|
||||||
assert q.empty() is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_thread_safety():
|
|
||||||
"""Test concurrent access from multiple threads."""
|
|
||||||
q = ExecutionQueue()
|
|
||||||
results = []
|
|
||||||
num_items = 100
|
|
||||||
|
|
||||||
def producer():
|
|
||||||
for i in range(num_items):
|
|
||||||
q.add(f"item_{i}")
|
|
||||||
|
|
||||||
def consumer():
|
|
||||||
count = 0
|
|
||||||
while count < num_items:
|
|
||||||
item = q.get_or_none()
|
|
||||||
if item is not None:
|
|
||||||
results.append(item)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
producer_thread = threading.Thread(target=producer)
|
|
||||||
consumer_thread = threading.Thread(target=consumer)
|
|
||||||
|
|
||||||
producer_thread.start()
|
|
||||||
consumer_thread.start()
|
|
||||||
|
|
||||||
producer_thread.join(timeout=5)
|
|
||||||
consumer_thread.join(timeout=5)
|
|
||||||
|
|
||||||
assert len(results) == num_items
|
|
||||||
@@ -447,8 +447,7 @@ class GraphModel(Graph, GraphMeta):
     @computed_field
     @property
     def credentials_input_schema(self) -> dict[str, Any]:
-        graph_credentials_inputs = self.aggregate_credentials_inputs()
-
+        graph_credentials_inputs = self.regular_credentials_inputs
         logger.debug(
             f"Combined credentials input fields for graph #{self.id} ({self.name}): "
             f"{graph_credentials_inputs}"
@@ -604,6 +603,28 @@ class GraphModel(Graph, GraphMeta):
             for key, (field_info, node_field_pairs) in combined.items()
         }
 
+    @property
+    def regular_credentials_inputs(
+        self,
+    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
+        """Credentials that need explicit user mapping (CredentialsMetaInput fields)."""
+        return {
+            k: v
+            for k, v in self.aggregate_credentials_inputs().items()
+            if not v[0].is_auto_credential
+        }
+
+    @property
+    def auto_credentials_inputs(
+        self,
+    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
+        """Credentials embedded in file fields (_credentials_id), resolved at execution time."""
+        return {
+            k: v
+            for k, v in self.aggregate_credentials_inputs().items()
+            if v[0].is_auto_credential
+        }
+
     def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
         """
         Reassigns all IDs in the graph to new UUIDs.
@@ -654,6 +675,16 @@ class GraphModel(Graph, GraphMeta):
                 ) and graph_id in graph_id_map:
                     node.input_default["graph_id"] = graph_id_map[graph_id]
 
+        # Clear auto-credentials references (e.g., _credentials_id in
+        # GoogleDriveFile fields) so the new user must re-authenticate
+        # with their own account
+        for node in graph.nodes:
+            if not node.input_default:
+                continue
+            for key, value in node.input_default.items():
+                if isinstance(value, dict) and "_credentials_id" in value:
+                    del value["_credentials_id"]
+
     def validate_graph(
         self,
         for_run: bool = False,
|
|||||||
@@ -463,3 +463,329 @@ def test_node_credentials_optional_with_other_metadata():
|
|||||||
assert node.credentials_optional is True
|
assert node.credentials_optional is True
|
||||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||||
assert node.metadata["customized_name"] == "My Custom Node"
|
assert node.metadata["customized_name"] == "My Custom Node"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for CredentialsFieldInfo.combine() field propagation
|
||||||
|
def test_combine_preserves_is_auto_credential_flag():
|
||||||
|
"""
|
||||||
|
CredentialsFieldInfo.combine() must propagate is_auto_credential and
|
||||||
|
input_field_name to the combined result. Regression test for reviewer
|
||||||
|
finding that combine() dropped these fields.
|
||||||
|
"""
|
||||||
|
from backend.data.model import CredentialsFieldInfo
|
||||||
|
|
||||||
|
auto_field = CredentialsFieldInfo.model_validate(
|
||||||
|
{
|
||||||
|
"credentials_provider": ["google"],
|
||||||
|
"credentials_types": ["oauth2"],
|
||||||
|
"credentials_scopes": ["drive.readonly"],
|
||||||
|
"is_auto_credential": True,
|
||||||
|
"input_field_name": "spreadsheet",
|
||||||
|
},
|
||||||
|
by_alias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# combine() takes *args of (field_info, key) tuples
|
||||||
|
combined = CredentialsFieldInfo.combine(
|
||||||
|
(auto_field, ("node-1", "credentials")),
|
||||||
|
(auto_field, ("node-2", "credentials")),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(combined) == 1
|
||||||
|
group_key = next(iter(combined))
|
||||||
|
combined_info, combined_keys = combined[group_key]
|
||||||
|
|
||||||
|
assert combined_info.is_auto_credential is True
|
||||||
|
assert combined_info.input_field_name == "spreadsheet"
|
||||||
|
assert combined_keys == {("node-1", "credentials"), ("node-2", "credentials")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_combine_preserves_regular_credential_defaults():
|
||||||
|
"""Regular credentials should have is_auto_credential=False after combine()."""
|
||||||
|
from backend.data.model import CredentialsFieldInfo
|
||||||
|
|
||||||
|
regular_field = CredentialsFieldInfo.model_validate(
|
||||||
|
{
|
||||||
|
"credentials_provider": ["github"],
|
||||||
|
"credentials_types": ["api_key"],
|
||||||
|
"is_auto_credential": False,
|
||||||
|
},
|
||||||
|
by_alias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
combined = CredentialsFieldInfo.combine(
|
||||||
|
(regular_field, ("node-1", "credentials")),
|
||||||
|
)
|
||||||
|
|
||||||
|
group_key = next(iter(combined))
|
||||||
|
combined_info, _ = combined[group_key]
|
||||||
|
|
||||||
|
assert combined_info.is_auto_credential is False
|
||||||
|
assert combined_info.input_field_name is None
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for _reassign_ids credential clearing (Fix 3: SECRT-1772)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reassign_ids_clears_credentials_id():
|
||||||
|
"""
|
||||||
|
[SECRT-1772] _reassign_ids should clear _credentials_id from
|
||||||
|
GoogleDriveFile-style input_default fields so forked agents
|
||||||
|
don't retain the original creator's credential references.
|
||||||
|
"""
|
||||||
|
from backend.data.graph import GraphModel
|
||||||
|
|
||||||
|
node = Node(
|
||||||
|
id="node-1",
|
||||||
|
block_id=StoreValueBlock().id,
|
||||||
|
input_default={
|
||||||
|
"spreadsheet": {
|
||||||
|
"_credentials_id": "original-cred-id",
|
||||||
|
"id": "file-123",
|
||||||
|
"name": "test.xlsx",
|
||||||
|
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||||
|
"url": "https://docs.google.com/spreadsheets/d/file-123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = Graph(
|
||||||
|
id="test-graph",
|
||||||
|
name="Test",
|
||||||
|
description="Test",
|
||||||
|
nodes=[node],
|
||||||
|
links=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||||
|
|
||||||
|
# _credentials_id key should be removed (not set to None) so that
|
||||||
|
# _acquire_auto_credentials correctly errors instead of treating it as chained data
|
||||||
|
assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_reassign_ids_preserves_non_credential_fields():
|
||||||
|
"""
|
||||||
|
Regression guard: _reassign_ids should NOT modify non-credential fields
|
||||||
|
like name, mimeType, id, url.
|
||||||
|
"""
|
||||||
|
from backend.data.graph import GraphModel
|
||||||
|
|
||||||
|
node = Node(
|
||||||
|
id="node-1",
|
||||||
|
block_id=StoreValueBlock().id,
|
||||||
|
input_default={
|
||||||
|
"spreadsheet": {
|
||||||
|
"_credentials_id": "cred-abc",
|
||||||
|
"id": "file-123",
|
||||||
|
"name": "test.xlsx",
|
||||||
|
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||||
|
"url": "https://docs.google.com/spreadsheets/d/file-123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = Graph(
|
||||||
|
id="test-graph",
|
||||||
|
name="Test",
|
||||||
|
description="Test",
|
||||||
|
nodes=[node],
|
||||||
|
links=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||||
|
|
||||||
|
field = graph.nodes[0].input_default["spreadsheet"]
|
||||||
|
assert field["id"] == "file-123"
|
||||||
|
assert field["name"] == "test.xlsx"
|
||||||
|
assert field["mimeType"] == "application/vnd.google-apps.spreadsheet"
|
||||||
|
assert field["url"] == "https://docs.google.com/spreadsheets/d/file-123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_reassign_ids_handles_no_credentials():
|
||||||
|
"""
|
||||||
|
Regression guard: _reassign_ids should not error when input_default
|
||||||
|
has no dict fields with _credentials_id.
|
||||||
|
"""
|
||||||
|
from backend.data.graph import GraphModel
|
||||||
|
|
||||||
|
node = Node(
|
||||||
|
id="node-1",
|
||||||
|
block_id=StoreValueBlock().id,
|
||||||
|
input_default={
|
||||||
|
"input": "some value",
|
||||||
|
"another_input": 42,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = Graph(
|
||||||
|
id="test-graph",
|
||||||
|
name="Test",
|
||||||
|
description="Test",
|
||||||
|
nodes=[node],
|
||||||
|
links=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||||
|
|
||||||
|
# Should not error, fields unchanged
|
||||||
|
assert graph.nodes[0].input_default["input"] == "some value"
|
||||||
|
assert graph.nodes[0].input_default["another_input"] == 42
|
||||||
|
|
||||||
|
|
||||||
|
def test_reassign_ids_handles_multiple_credential_fields():
|
||||||
|
"""
|
||||||
|
[SECRT-1772] When a node has multiple dict fields with _credentials_id,
|
||||||
|
ALL of them should be cleared.
|
||||||
|
"""
|
||||||
|
from backend.data.graph import GraphModel
|
||||||
|
|
||||||
|
node = Node(
|
||||||
|
id="node-1",
|
||||||
|
block_id=StoreValueBlock().id,
|
||||||
|
input_default={
|
||||||
|
"spreadsheet": {
|
||||||
|
"_credentials_id": "cred-1",
|
||||||
|
"id": "file-1",
|
||||||
|
"name": "file1.xlsx",
|
||||||
|
},
|
||||||
|
"doc_file": {
|
||||||
|
"_credentials_id": "cred-2",
|
||||||
|
"id": "file-2",
|
||||||
|
"name": "file2.docx",
|
||||||
|
},
|
||||||
|
"plain_input": "not a dict",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = Graph(
|
||||||
|
id="test-graph",
|
||||||
|
name="Test",
|
||||||
|
description="Test",
|
||||||
|
nodes=[node],
|
||||||
|
links=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||||
|
|
||||||
|
assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
|
||||||
|
assert "_credentials_id" not in graph.nodes[0].input_default["doc_file"]
|
||||||
|
assert graph.nodes[0].input_default["plain_input"] == "not a dict"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for discriminate() field propagation
|
||||||
|
def test_discriminate_preserves_is_auto_credential_flag():
|
||||||
|
"""
|
||||||
|
CredentialsFieldInfo.discriminate() must propagate is_auto_credential and
|
||||||
|
input_field_name to the discriminated result. Regression test for
|
||||||
|
discriminate() dropping these fields (same class of bug as combine()).
|
||||||
|
"""
|
||||||
|
from backend.data.model import CredentialsFieldInfo
|
||||||
|
|
||||||
|
auto_field = CredentialsFieldInfo.model_validate(
|
||||||
|
{
|
||||||
|
"credentials_provider": ["google", "openai"],
|
||||||
|
"credentials_types": ["oauth2"],
|
||||||
|
"credentials_scopes": ["drive.readonly"],
|
||||||
|
"is_auto_credential": True,
|
||||||
|
"input_field_name": "spreadsheet",
|
||||||
|
"discriminator": "model",
|
||||||
|
"discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
|
||||||
|
},
|
||||||
|
by_alias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
discriminated = auto_field.discriminate("gemini")
|
||||||
|
|
||||||
|
assert discriminated.is_auto_credential is True
|
||||||
|
assert discriminated.input_field_name == "spreadsheet"
|
||||||
|
assert discriminated.provider == frozenset(["google"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_discriminate_preserves_regular_credential_defaults():
|
||||||
|
"""Regular credentials should have is_auto_credential=False after discriminate()."""
|
||||||
|
from backend.data.model import CredentialsFieldInfo
|
||||||
|
|
||||||
|
regular_field = CredentialsFieldInfo.model_validate(
|
||||||
|
{
|
||||||
|
"credentials_provider": ["google", "openai"],
|
||||||
|
"credentials_types": ["api_key"],
|
||||||
|
"is_auto_credential": False,
|
||||||
|
"discriminator": "model",
|
||||||
|
"discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
|
||||||
|
},
|
||||||
|
by_alias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
discriminated = regular_field.discriminate("gpt-4")
|
||||||
|
|
||||||
|
assert discriminated.is_auto_credential is False
|
||||||
|
assert discriminated.input_field_name is None
|
||||||
|
assert discriminated.provider == frozenset(["openai"])
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for credentials_input_schema excluding auto_credentials
|
||||||
|
def test_credentials_input_schema_excludes_auto_creds():
|
||||||
|
"""
|
||||||
|
GraphModel.credentials_input_schema should exclude auto_credentials
|
||||||
|
(is_auto_credential=True) from the schema. Auto_credentials are
|
||||||
|
transparently resolved at execution time via file picker data.
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import PropertyMock, patch
|
||||||
|
|
||||||
|
from backend.data.graph import GraphModel, NodeModel
|
||||||
|
from backend.data.model import CredentialsFieldInfo
|
||||||
|
|
||||||
|
regular_field_info = CredentialsFieldInfo.model_validate(
|
||||||
|
{
|
||||||
|
"credentials_provider": ["github"],
|
||||||
|
"credentials_types": ["api_key"],
|
||||||
|
"is_auto_credential": False,
|
||||||
|
},
|
||||||
|
by_alias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = GraphModel(
|
||||||
|
id="test-graph",
|
||||||
|
version=1,
|
||||||
|
name="Test",
|
||||||
|
description="Test",
|
||||||
|
user_id="test-user",
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
nodes=[
|
||||||
|
NodeModel(
|
||||||
|
id="node-1",
|
||||||
|
block_id=StoreValueBlock().id,
|
||||||
|
input_default={},
|
||||||
|
graph_id="test-graph",
|
||||||
|
graph_version=1,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
links=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock regular_credentials_inputs to return only the non-auto field (3-tuple)
|
||||||
|
regular_only = {
|
||||||
|
"github_credentials": (
|
||||||
|
regular_field_info,
|
||||||
|
{("node-1", "credentials")},
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
type(graph),
|
||||||
|
"regular_credentials_inputs",
|
||||||
|
new_callable=PropertyMock,
|
||||||
|
return_value=regular_only,
|
||||||
|
):
|
||||||
|
schema = graph.credentials_input_schema
|
||||||
|
field_names = set(schema.get("properties", {}).keys())
|
||||||
|
# Should include regular credential but NOT auto_credential
|
||||||
|
assert "github_credentials" in field_names
|
||||||
|
assert "google_credentials" not in field_names
|
||||||
|
|||||||
@@ -571,6 +571,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
     discriminator: Optional[str] = None
     discriminator_mapping: Optional[dict[str, CP]] = None
     discriminator_values: set[Any] = Field(default_factory=set)
+    is_auto_credential: bool = False
+    input_field_name: Optional[str] = None
 
     @classmethod
     def combine(
@@ -651,6 +653,9 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
                 + "_credentials"
             )
 
+            # Propagate is_auto_credential from the combined field.
+            # All fields in a group should share the same is_auto_credential
+            # value since auto and regular credentials serve different purposes.
             result[group_key] = (
                 CredentialsFieldInfo[CP, CT](
                     credentials_provider=combined.provider,
@@ -659,6 +664,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
                     discriminator=combined.discriminator,
                     discriminator_mapping=combined.discriminator_mapping,
                     discriminator_values=set(all_discriminator_values),
+                    is_auto_credential=combined.is_auto_credential,
+                    input_field_name=combined.input_field_name,
                 ),
                 combined_keys,
             )
@@ -684,6 +691,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
             discriminator=self.discriminator,
             discriminator_mapping=self.discriminator_mapping,
             discriminator_values=self.discriminator_values,
+            is_auto_credential=self.is_auto_credential,
+            input_field_name=self.input_field_name,
         )
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -226,10 +225,6 @@ class SyncRabbitMQ(RabbitMQBase):
 class AsyncRabbitMQ(RabbitMQBase):
     """Asynchronous RabbitMQ client"""
 
-    def __init__(self, config: RabbitMQConfig):
-        super().__init__(config)
-        self._reconnect_lock: asyncio.Lock | None = None
-
     @property
     def is_connected(self) -> bool:
         return bool(self._connection and not self._connection.is_closed)
||||||
@@ -240,17 +235,7 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
|
|
||||||
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
if self.is_connected and self._channel and not self._channel.is_closed:
|
if self.is_connected:
|
||||||
return
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.is_connected
|
|
||||||
and self._connection
|
|
||||||
and (self._channel is None or self._channel.is_closed)
|
|
||||||
):
|
|
||||||
self._channel = await self._connection.channel()
|
|
||||||
await self._channel.set_qos(prefetch_count=1)
|
|
||||||
await self.declare_infrastructure()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self._connection = await aio_pika.connect_robust(
|
self._connection = await aio_pika.connect_robust(
|
||||||
@@ -306,46 +291,24 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
exchange, routing_key=queue.routing_key or queue.name
|
exchange, routing_key=queue.routing_key or queue.name
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@func_retry
|
||||||
def _lock(self) -> asyncio.Lock:
|
async def publish_message(
|
||||||
if self._reconnect_lock is None:
|
|
||||||
self._reconnect_lock = asyncio.Lock()
|
|
||||||
return self._reconnect_lock
|
|
||||||
|
|
||||||
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
|
|
||||||
"""Get a valid channel, reconnecting if the current one is stale.
|
|
||||||
|
|
||||||
Uses a lock to prevent concurrent reconnection attempts from racing.
|
|
||||||
"""
|
|
||||||
if self.is_ready:
|
|
||||||
return self._channel # type: ignore # is_ready guarantees non-None
|
|
||||||
|
|
||||||
async with self._lock:
|
|
||||||
# Double-check after acquiring lock
|
|
||||||
if self.is_ready:
|
|
||||||
return self._channel # type: ignore
|
|
||||||
|
|
||||||
self._channel = None
|
|
||||||
await self.connect()
|
|
||||||
|
|
||||||
if self._channel is None:
|
|
||||||
raise RuntimeError("Channel should be established after connect")
|
|
||||||
|
|
||||||
return self._channel
|
|
||||||
|
|
||||||
async def _publish_once(
|
|
||||||
self,
|
self,
|
||||||
routing_key: str,
|
routing_key: str,
|
||||||
message: str,
|
message: str,
|
||||||
exchange: Optional[Exchange] = None,
|
exchange: Optional[Exchange] = None,
|
||||||
persistent: bool = True,
|
persistent: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
channel = await self._ensure_channel()
|
if not self.is_ready:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
if self._channel is None:
|
||||||
|
raise RuntimeError("Channel should be established after connect")
|
||||||
|
|
||||||
if exchange:
|
if exchange:
|
||||||
exchange_obj = await channel.get_exchange(exchange.name)
|
exchange_obj = await self._channel.get_exchange(exchange.name)
|
||||||
else:
|
else:
|
||||||
exchange_obj = channel.default_exchange
|
exchange_obj = self._channel.default_exchange
|
||||||
|
|
||||||
await exchange_obj.publish(
|
await exchange_obj.publish(
|
||||||
aio_pika.Message(
|
aio_pika.Message(
|
||||||
@@ -359,23 +322,9 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
routing_key=routing_key,
|
routing_key=routing_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@func_retry
|
|
||||||
async def publish_message(
|
|
||||||
self,
|
|
||||||
routing_key: str,
|
|
||||||
message: str,
|
|
||||||
exchange: Optional[Exchange] = None,
|
|
||||||
persistent: bool = True,
|
|
||||||
) -> None:
|
|
||||||
try:
|
|
||||||
await self._publish_once(routing_key, message, exchange, persistent)
|
|
||||||
except aio_pika.exceptions.ChannelInvalidStateError:
|
|
||||||
logger.warning(
|
|
||||||
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
|
|
||||||
)
|
|
||||||
async with self._lock:
|
|
||||||
self._channel = None
|
|
||||||
await self._publish_once(routing_key, message, exchange, persistent)
|
|
||||||
|
|
||||||
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||||
return await self._ensure_channel()
|
if not self.is_ready:
|
||||||
|
await self.connect()
|
||||||
|
if self._channel is None:
|
||||||
|
raise RuntimeError("Channel should be established after connect")
|
||||||
|
return self._channel
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ from .utils import (
|
|||||||
block_usage_cost,
|
block_usage_cost,
|
||||||
create_execution_queue_config,
|
create_execution_queue_config,
|
||||||
execution_usage_cost,
|
execution_usage_cost,
|
||||||
|
parse_auto_credential_field,
|
||||||
validate_exec,
|
validate_exec,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,6 +173,60 @@ def execute_graph(
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
async def _acquire_auto_credentials(
|
||||||
|
input_model: type[BlockSchema],
|
||||||
|
input_data: dict[str, Any],
|
||||||
|
creds_manager: "IntegrationCredentialsManager",
|
||||||
|
user_id: str,
|
||||||
|
) -> tuple[dict[str, Any], list[AsyncRedisLock]]:
|
||||||
|
"""
|
||||||
|
Resolve auto_credentials from GoogleDriveFileField-style inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(extra_exec_kwargs, locks): kwargs to inject into block execution, and
|
||||||
|
credential locks to release after execution completes.
|
||||||
|
"""
|
||||||
|
extra_exec_kwargs: dict[str, Any] = {}
|
||||||
|
locks: list[AsyncRedisLock] = []
|
||||||
|
|
||||||
|
# NOTE: If a block ever has multiple auto-credential fields, a ValueError
|
||||||
|
# on a later field will strand locks acquired for earlier fields. They'll
|
||||||
|
# auto-expire via Redis TTL, but add a try/except to release partial locks
|
||||||
|
# if that becomes a real scenario.
|
||||||
|
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
|
||||||
|
field_name = info["field_name"]
|
||||||
|
field_data = input_data.get(field_name)
|
||||||
|
|
||||||
|
parsed = parse_auto_credential_field(
|
||||||
|
field_name=field_name,
|
||||||
|
info=info,
|
||||||
|
field_data=field_data,
|
||||||
|
field_present_in_input=field_name in input_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
if parsed.error:
|
||||||
|
raise ValueError(parsed.error)
|
||||||
|
|
||||||
|
if parsed.cred_id:
|
||||||
|
# Credential ID provided - acquire credentials
|
||||||
|
try:
|
||||||
|
credentials, lock = await creds_manager.acquire(user_id, parsed.cred_id)
|
||||||
|
locks.append(lock)
|
||||||
|
extra_exec_kwargs[kwarg_name] = credentials
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"{parsed.provider.capitalize()} credentials for "
|
||||||
|
f"'{parsed.file_name}' in field '{parsed.field_name}' are not "
|
||||||
|
f"available in your account. "
|
||||||
|
f"This can happen if the agent was created by another "
|
||||||
|
f"user or the credentials were deleted. "
|
||||||
|
f"Please open the agent in the builder and re-select "
|
||||||
|
f"the file to authenticate with your own account."
|
||||||
|
)
|
||||||
|
|
||||||
|
return extra_exec_kwargs, locks
|
||||||
|
|
||||||
|
|
||||||
async def execute_node(
|
async def execute_node(
|
||||||
node: Node,
|
node: Node,
|
||||||
data: NodeExecutionEntry,
|
data: NodeExecutionEntry,
|
||||||
@@ -271,41 +326,14 @@ async def execute_node(
|
|||||||
extra_exec_kwargs[field_name] = credentials
|
extra_exec_kwargs[field_name] = credentials
|
||||||
|
|
||||||
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
|
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
|
||||||
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
|
auto_extra_kwargs, auto_locks = await _acquire_auto_credentials(
|
||||||
field_name = info["field_name"]
|
input_model=input_model,
|
||||||
field_data = input_data.get(field_name)
|
input_data=input_data,
|
||||||
if field_data and isinstance(field_data, dict):
|
creds_manager=creds_manager,
|
||||||
# Check if _credentials_id key exists in the field data
|
user_id=user_id,
|
||||||
if "_credentials_id" in field_data:
|
)
|
||||||
cred_id = field_data["_credentials_id"]
|
extra_exec_kwargs.update(auto_extra_kwargs)
|
||||||
if cred_id:
|
creds_locks.extend(auto_locks)
|
||||||
# Credential ID provided - acquire credentials
|
|
||||||
provider = info.get("config", {}).get(
|
|
||||||
"provider", "external service"
|
|
||||||
)
|
|
||||||
file_name = field_data.get("name", "selected file")
|
|
||||||
try:
|
|
||||||
credentials, lock = await creds_manager.acquire(
|
|
||||||
user_id, cred_id
|
|
||||||
)
|
|
||||||
creds_locks.append(lock)
|
|
||||||
extra_exec_kwargs[kwarg_name] = credentials
|
|
||||||
except ValueError:
|
|
||||||
# Credential was deleted or doesn't exist
|
|
||||||
raise ValueError(
|
|
||||||
f"Authentication expired for '{file_name}' in field '{field_name}'. "
|
|
||||||
f"The saved {provider.capitalize()} credentials no longer exist. "
|
|
||||||
f"Please re-select the file to re-authenticate."
|
|
||||||
)
|
|
||||||
# else: _credentials_id is explicitly None, skip credentials (for chained data)
|
|
||||||
else:
|
|
||||||
# _credentials_id key missing entirely - this is an error
|
|
||||||
provider = info.get("config", {}).get("provider", "external service")
|
|
||||||
file_name = field_data.get("name", "selected file")
|
|
||||||
raise ValueError(
|
|
||||||
f"Authentication missing for '{file_name}' in field '{field_name}'. "
|
|
||||||
f"Please re-select the file to authenticate with {provider.capitalize()}."
|
|
||||||
)
|
|
||||||
|
|
||||||
output_size = 0
|
output_size = 0
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,320 @@
|
|||||||
|
"""
|
||||||
|
Tests for auto_credentials handling in execute_node().
|
||||||
|
|
||||||
|
These test the _acquire_auto_credentials() helper function extracted from
|
||||||
|
execute_node() (manager.py lines 273-308).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def google_drive_file_data():
|
||||||
|
return {
|
||||||
|
"valid": {
|
||||||
|
"_credentials_id": "cred-id-123",
|
||||||
|
"id": "file-123",
|
||||||
|
"name": "test.xlsx",
|
||||||
|
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||||
|
},
|
||||||
|
"chained": {
|
||||||
|
"_credentials_id": None,
|
||||||
|
"id": "file-456",
|
||||||
|
"name": "chained.xlsx",
|
||||||
|
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||||
|
},
|
||||||
|
"missing_key": {
|
||||||
|
"id": "file-789",
|
||||||
|
"name": "bad.xlsx",
|
||||||
|
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_input_model(mocker: MockerFixture):
|
||||||
|
"""Create a mock input model with get_auto_credentials_fields() returning one field."""
|
||||||
|
input_model = mocker.MagicMock()
|
||||||
|
input_model.get_auto_credentials_fields.return_value = {
|
||||||
|
"credentials": {
|
||||||
|
"field_name": "spreadsheet",
|
||||||
|
"config": {
|
||||||
|
"provider": "google",
|
||||||
|
"type": "oauth2",
|
||||||
|
"scopes": ["https://www.googleapis.com/auth/drive.readonly"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return input_model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_creds_manager(mocker: MockerFixture):
|
||||||
|
manager = mocker.AsyncMock()
|
||||||
|
mock_lock = mocker.AsyncMock()
|
||||||
|
mock_creds = mocker.MagicMock()
|
||||||
|
mock_creds.id = "cred-id-123"
|
||||||
|
mock_creds.provider = "google"
|
||||||
|
manager.acquire.return_value = (mock_creds, mock_lock)
|
||||||
|
return manager, mock_creds, mock_lock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_credentials_happy_path(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
google_drive_file_data,
|
||||||
|
mock_input_model,
|
||||||
|
mock_creds_manager,
|
||||||
|
):
|
||||||
|
"""When field_data has a valid _credentials_id, credentials should be acquired."""
|
||||||
|
from backend.executor.manager import _acquire_auto_credentials
|
||||||
|
|
||||||
|
manager, mock_creds, mock_lock = mock_creds_manager
|
||||||
|
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||||
|
|
||||||
|
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||||
|
input_model=mock_input_model,
|
||||||
|
input_data=input_data,
|
||||||
|
creds_manager=manager,
|
||||||
|
user_id="user-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
manager.acquire.assert_called_once_with("user-1", "cred-id-123")
|
||||||
|
assert extra_kwargs["credentials"] == mock_creds
|
||||||
|
assert mock_lock in locks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_credentials_field_none_static_raises(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
mock_input_model,
|
||||||
|
mock_creds_manager,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
[THE BUG FIX TEST — OPEN-2895]
|
||||||
|
When field_data is None and the key IS in input_data (user didn't select a file),
|
||||||
|
should raise ValueError instead of silently skipping.
|
||||||
|
"""
|
||||||
|
from backend.executor.manager import _acquire_auto_credentials
|
||||||
|
|
||||||
|
manager, _, _ = mock_creds_manager
|
||||||
|
# Key is present but value is None = user didn't select a file
|
||||||
|
input_data = {"spreadsheet": None}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No file selected"):
|
||||||
|
await _acquire_auto_credentials(
|
||||||
|
input_model=mock_input_model,
|
||||||
|
input_data=input_data,
|
||||||
|
creds_manager=manager,
|
||||||
|
user_id="user-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_credentials_field_absent_skips(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
mock_input_model,
|
||||||
|
mock_creds_manager,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
When the field key is NOT in input_data at all (upstream connection),
|
||||||
|
should skip without error.
|
||||||
|
"""
|
||||||
|
from backend.executor.manager import _acquire_auto_credentials
|
||||||
|
|
||||||
|
manager, _, _ = mock_creds_manager
|
||||||
|
# Key not present = connected from upstream block
|
||||||
|
input_data = {}
|
||||||
|
|
||||||
|
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||||
|
input_model=mock_input_model,
|
||||||
|
input_data=input_data,
|
||||||
|
creds_manager=manager,
|
||||||
|
user_id="user-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
manager.acquire.assert_not_called()
|
||||||
|
assert "credentials" not in extra_kwargs
|
||||||
|
assert locks == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_credentials_chained_cred_id_none(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
google_drive_file_data,
|
||||||
|
mock_input_model,
|
||||||
|
mock_creds_manager,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
When _credentials_id is explicitly None (chained data from upstream),
|
||||||
|
should skip credential acquisition.
|
||||||
|
"""
|
||||||
|
from backend.executor.manager import _acquire_auto_credentials
|
||||||
|
|
||||||
|
manager, _, _ = mock_creds_manager
|
||||||
|
input_data = {"spreadsheet": google_drive_file_data["chained"]}
|
||||||
|
|
||||||
|
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||||
|
input_model=mock_input_model,
|
||||||
|
input_data=input_data,
|
||||||
|
creds_manager=manager,
|
||||||
|
user_id="user-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
manager.acquire.assert_not_called()
|
||||||
|
assert "credentials" not in extra_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_credentials_missing_cred_id_key_raises(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
google_drive_file_data,
|
||||||
|
mock_input_model,
|
||||||
|
mock_creds_manager,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
When _credentials_id key is missing entirely from field_data dict,
|
||||||
|
should raise ValueError.
|
||||||
|
"""
|
||||||
|
from backend.executor.manager import _acquire_auto_credentials
|
||||||
|
|
||||||
|
manager, _, _ = mock_creds_manager
|
||||||
|
input_data = {"spreadsheet": google_drive_file_data["missing_key"]}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Authentication missing"):
|
||||||
|
await _acquire_auto_credentials(
|
||||||
|
input_model=mock_input_model,
|
||||||
|
input_data=input_data,
|
||||||
|
creds_manager=manager,
|
||||||
|
user_id="user-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_credentials_ownership_mismatch_error(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
google_drive_file_data,
|
||||||
|
mock_input_model,
|
||||||
|
mock_creds_manager,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
[SECRT-1772] When acquire() raises ValueError (credential belongs to another user),
|
||||||
|
the error message should mention 'not available' (not 'expired').
|
||||||
|
"""
|
||||||
|
from backend.executor.manager import _acquire_auto_credentials
|
||||||
|
|
||||||
|
manager, _, _ = mock_creds_manager
|
||||||
|
manager.acquire.side_effect = ValueError(
|
||||||
|
"Credentials #cred-id-123 for user #user-2 not found"
|
||||||
|
)
|
||||||
|
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not available in your account"):
|
||||||
|
await _acquire_auto_credentials(
|
||||||
|
input_model=mock_input_model,
|
||||||
|
input_data=input_data,
|
||||||
|
creds_manager=manager,
|
||||||
|
user_id="user-2",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_credentials_deleted_credential_error(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
google_drive_file_data,
|
||||||
|
mock_input_model,
|
||||||
|
mock_creds_manager,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
[SECRT-1772] When acquire() raises ValueError (credential was deleted),
|
||||||
|
the error message should mention 'not available' (not 'expired').
|
||||||
|
"""
|
||||||
|
from backend.executor.manager import _acquire_auto_credentials
|
||||||
|
|
||||||
|
manager, _, _ = mock_creds_manager
|
||||||
|
manager.acquire.side_effect = ValueError(
|
||||||
|
"Credentials #cred-id-123 for user #user-1 not found"
|
||||||
|
)
|
||||||
|
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not available in your account"):
|
||||||
|
await _acquire_auto_credentials(
|
||||||
|
input_model=mock_input_model,
|
||||||
|
input_data=input_data,
|
||||||
|
creds_manager=manager,
|
||||||
|
user_id="user-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_credentials_lock_appended(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
google_drive_file_data,
|
||||||
|
mock_input_model,
|
||||||
|
mock_creds_manager,
|
||||||
|
):
|
||||||
|
"""Lock from acquire() should be included in returned locks list."""
|
||||||
|
from backend.executor.manager import _acquire_auto_credentials
|
||||||
|
|
||||||
|
manager, _, mock_lock = mock_creds_manager
|
||||||
|
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||||
|
|
||||||
|
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||||
|
input_model=mock_input_model,
|
||||||
|
input_data=input_data,
|
||||||
|
creds_manager=manager,
|
||||||
|
user_id="user-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(locks) == 1
|
||||||
|
assert locks[0] is mock_lock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_credentials_multiple_fields(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
mock_creds_manager,
|
||||||
|
):
|
||||||
|
"""When there are multiple auto_credentials fields, only valid ones should acquire."""
|
||||||
|
from backend.executor.manager import _acquire_auto_credentials
|
||||||
|
|
||||||
|
manager, mock_creds, mock_lock = mock_creds_manager
|
||||||
|
|
||||||
|
input_model = mocker.MagicMock()
|
||||||
|
input_model.get_auto_credentials_fields.return_value = {
|
||||||
|
"credentials": {
|
||||||
|
"field_name": "spreadsheet",
|
||||||
|
"config": {"provider": "google", "type": "oauth2"},
|
||||||
|
},
|
||||||
|
"credentials2": {
|
||||||
|
"field_name": "doc_file",
|
||||||
|
"config": {"provider": "google", "type": "oauth2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
input_data = {
|
||||||
|
"spreadsheet": {
|
||||||
|
"_credentials_id": "cred-id-123",
|
||||||
|
"id": "file-1",
|
||||||
|
"name": "file1.xlsx",
|
||||||
|
},
|
||||||
|
"doc_file": {
|
||||||
|
"_credentials_id": None,
|
||||||
|
"id": "file-2",
|
||||||
|
"name": "chained.doc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||||
|
input_model=input_model,
|
||||||
|
input_data=input_data,
|
||||||
|
creds_manager=manager,
|
||||||
|
user_id="user-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only the first field should have acquired credentials
|
||||||
|
manager.acquire.assert_called_once_with("user-1", "cred-id-123")
|
||||||
|
assert "credentials" in extra_kwargs
|
||||||
|
assert "credentials2" not in extra_kwargs
|
||||||
|
assert len(locks) == 1
|
||||||
@@ -4,7 +4,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from typing import Mapping, Optional, cast
|
from typing import Any, Mapping, Optional, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, JsonValue, ValidationError
|
from pydantic import BaseModel, JsonValue, ValidationError
|
||||||
|
|
||||||
@@ -55,6 +55,87 @@ from backend.util.type import convert
|
|||||||
config = Config()
|
config = Config()
|
||||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
||||||
|
|
||||||
|
# ============ Auto-Credentials Helpers ============ #
|
||||||
|
|
||||||
|
|
||||||
|
class AutoCredentialFieldInfo(BaseModel):
|
||||||
|
"""Parsed info from an auto-credential field (e.g., GoogleDriveFileField)."""
|
||||||
|
|
||||||
|
cred_id: str | None
|
||||||
|
"""The credential ID to use, or None if not provided."""
|
||||||
|
provider: str
|
||||||
|
"""The provider name (e.g., 'google')."""
|
||||||
|
file_name: str
|
||||||
|
"""The display name for error messages."""
|
||||||
|
field_name: str
|
||||||
|
"""The original field name in the schema."""
|
||||||
|
error: str | None = None
|
||||||
|
"""Validation error message, if any."""
|
||||||
|
|
||||||
|
|
||||||
|
def parse_auto_credential_field(
|
||||||
|
field_name: str,
|
||||||
|
info: dict,
|
||||||
|
field_data: Any,
|
||||||
|
*,
|
||||||
|
field_present_in_input: bool = True,
|
||||||
|
) -> AutoCredentialFieldInfo:
|
||||||
|
"""
|
||||||
|
Parse auto-credential field data and extract credential info.
|
||||||
|
|
||||||
|
This is shared logic used by both credential acquisition (manager.py)
|
||||||
|
and credential validation (utils.py).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name: The name of the field in the schema
|
||||||
|
info: The auto_credentials field info from get_auto_credentials_fields()
|
||||||
|
field_data: The actual field data from input
|
||||||
|
field_present_in_input: Whether the field key exists in input_data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AutoCredentialFieldInfo with parsed data and any validation errors
|
||||||
|
"""
|
||||||
|
provider = info.get("config", {}).get("provider", "external service")
|
||||||
|
file_name = (
|
||||||
|
field_data.get("name", "selected file")
|
||||||
|
if isinstance(field_data, dict)
|
||||||
|
else "selected file"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = AutoCredentialFieldInfo(
|
||||||
|
cred_id=None,
|
||||||
|
provider=provider,
|
||||||
|
file_name=file_name,
|
||||||
|
field_name=field_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if field_data and isinstance(field_data, dict):
|
||||||
|
if "_credentials_id" not in field_data:
|
||||||
|
# Key removed (e.g., on fork) — needs re-auth
|
||||||
|
result.error = (
|
||||||
|
f"Authentication missing for '{file_name}' in field "
|
||||||
|
f"'{field_name}'. Please re-select the file to authenticate "
|
||||||
|
f"with {provider.capitalize()}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cred_id = field_data.get("_credentials_id")
|
||||||
|
if cred_id:
|
||||||
|
result.cred_id = cred_id
|
||||||
|
# else: _credentials_id is explicitly None, skip (chained data)
|
||||||
|
elif field_data is None and not field_present_in_input:
|
||||||
|
# Field not in input_data at all = connected from upstream block, skip
|
||||||
|
pass
|
||||||
|
elif field_present_in_input:
|
||||||
|
# field_data is None/empty but key IS in input_data = user didn't select
|
||||||
|
result.error = (
|
||||||
|
f"No file selected for '{field_name}'. "
|
||||||
|
f"Please select a file to provide "
|
||||||
|
f"{provider.capitalize()} authentication."
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
# ============ Resource Helpers ============ #
|
# ============ Resource Helpers ============ #
|
||||||
|
|
||||||
|
|
||||||
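A hedged usage sketch for the parse_auto_credential_field helper added above (illustrative only; the example dicts and field names are made up, and the import path is assumed from the surrounding hunks):

from backend.executor.utils import parse_auto_credential_field

info = {"field_name": "spreadsheet", "config": {"provider": "google", "type": "oauth2"}}
selected_file = {"_credentials_id": "cred-123", "name": "report.xlsx"}

parsed = parse_auto_credential_field(
    field_name="spreadsheet",
    info=info,
    field_data=selected_file,
    field_present_in_input=True,
)
assert parsed.cred_id == "cred-123" and parsed.error is None

# A file dict whose _credentials_id key was stripped (e.g. after a fork) parses to an error instead.
forked_file = {"id": "file-1", "name": "report.xlsx"}
assert parse_auto_credential_field("spreadsheet", info, forked_file).error is not None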
@@ -259,7 +340,8 @@ async def _validate_node_input_credentials(
 
         # Find any fields of type CredentialsMetaInput
         credentials_fields = block.input_schema.get_credentials_fields()
-        if not credentials_fields:
+        auto_credentials_fields = block.input_schema.get_auto_credentials_fields()
+        if not credentials_fields and not auto_credentials_fields:
             continue
 
         # Track if any credential field is missing for this node
@@ -339,6 +421,52 @@ async def _validate_node_input_credentials(
|
|||||||
] = "Invalid credentials: type/provider mismatch"
|
] = "Invalid credentials: type/provider mismatch"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Validate auto-credentials (GoogleDriveFileField-based)
|
||||||
|
# These have _credentials_id embedded in the file field data
|
||||||
|
if auto_credentials_fields:
|
||||||
|
for _kwarg_name, info in auto_credentials_fields.items():
|
||||||
|
field_name = info["field_name"]
|
||||||
|
# Check input_default and nodes_input_masks for the field value
|
||||||
|
field_value = node.input_default.get(field_name)
|
||||||
|
if nodes_input_masks and node.id in nodes_input_masks:
|
||||||
|
field_value = nodes_input_masks[node.id].get(
|
||||||
|
field_name, field_value
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use shared helper to parse the field
|
||||||
|
parsed = parse_auto_credential_field(
|
||||||
|
field_name=field_name,
|
||||||
|
info=info,
|
||||||
|
field_data=field_value,
|
||||||
|
field_present_in_input=True, # For validation, assume present
|
||||||
|
)
|
||||||
|
|
||||||
|
if parsed.error:
|
||||||
|
has_missing_credentials = True
|
||||||
|
credential_errors[node.id][field_name] = parsed.error
|
||||||
|
continue
|
||||||
|
|
||||||
|
if parsed.cred_id:
|
||||||
|
# Validate that credentials exist and are accessible
|
||||||
|
try:
|
||||||
|
creds_store = get_integration_credentials_store()
|
||||||
|
creds = await creds_store.get_creds_by_id(
|
||||||
|
user_id, parsed.cred_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
has_missing_credentials = True
|
||||||
|
credential_errors[node.id][
|
||||||
|
field_name
|
||||||
|
] = f"Credentials not available: {e}"
|
||||||
|
continue
|
||||||
|
if not creds:
|
||||||
|
has_missing_credentials = True
|
||||||
|
credential_errors[node.id][field_name] = (
|
||||||
|
"The saved credentials are not available "
|
||||||
|
"for your account. Please re-select the file to "
|
||||||
|
"authenticate with your own account."
|
||||||
|
)
|
||||||
|
|
||||||
# If node has optional credentials and any are missing, mark for skipping
|
# If node has optional credentials and any are missing, mark for skipping
|
||||||
# But only if there are no other errors for this node
|
# But only if there are no other errors for this node
|
||||||
if (
|
if (
|
||||||
@@ -370,8 +498,9 @@ def make_node_credentials_input_map(
|
|||||||
"""
|
"""
|
||||||
result: dict[str, dict[str, JsonValue]] = {}
|
result: dict[str, dict[str, JsonValue]] = {}
|
||||||
|
|
||||||
# Get aggregated credentials fields for the graph
|
# Only map regular credentials (not auto_credentials, which are resolved
|
||||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
# at execution time from _credentials_id in file field data)
|
||||||
|
graph_cred_inputs = graph.regular_credentials_inputs
|
||||||
|
|
||||||
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
||||||
# Best-effort map: skip missing items
|
# Best-effort map: skip missing items
|
||||||
|
|||||||
@@ -907,3 +907,335 @@ async def test_stop_graph_execution_cascades_to_child_with_reviews(

    # Verify both parent and child status updates
    assert mock_execution_db.update_graph_execution_stats.call_count >= 1


# ============================================================================
# Tests for auto_credentials validation in _validate_node_input_credentials
# (Fix 3: SECRT-1772 + Fix 4: Path 4)
# ============================================================================


@pytest.mark.asyncio
async def test_validate_node_input_credentials_auto_creds_valid(
    mocker: MockerFixture,
):
    """
    [SECRT-1772] When a node has auto_credentials with a valid _credentials_id
    that exists in the store, validation should pass without errors.
    """
    from backend.executor.utils import _validate_node_input_credentials

    mock_node = mocker.MagicMock()
    mock_node.id = "node-with-auto-creds"
    mock_node.credentials_optional = False
    mock_node.input_default = {
        "spreadsheet": {
            "_credentials_id": "valid-cred-id",
            "id": "file-123",
            "name": "test.xlsx",
        }
    }

    mock_block = mocker.MagicMock()
    # No regular credentials fields
    mock_block.input_schema.get_credentials_fields.return_value = {}
    # Has auto_credentials fields
    mock_block.input_schema.get_auto_credentials_fields.return_value = {
        "credentials": {
            "field_name": "spreadsheet",
            "config": {"provider": "google", "type": "oauth2"},
        }
    }
    mock_node.block = mock_block

    mock_graph = mocker.MagicMock()
    mock_graph.nodes = [mock_node]

    # Mock the credentials store to return valid credentials
    mock_store = mocker.MagicMock()
    mock_creds = mocker.MagicMock()
    mock_creds.id = "valid-cred-id"
    mock_store.get_creds_by_id = mocker.AsyncMock(return_value=mock_creds)
    mocker.patch(
        "backend.executor.utils.get_integration_credentials_store",
        return_value=mock_store,
    )

    errors, nodes_to_skip = await _validate_node_input_credentials(
        graph=mock_graph,
        user_id="test-user",
        nodes_input_masks=None,
    )

    assert mock_node.id not in errors
    assert mock_node.id not in nodes_to_skip


@pytest.mark.asyncio
async def test_validate_node_input_credentials_auto_creds_missing(
    mocker: MockerFixture,
):
    """
    [SECRT-1772] When a node has auto_credentials with a _credentials_id
    that doesn't exist for the current user, validation should report an error.
    """
    from backend.executor.utils import _validate_node_input_credentials

    mock_node = mocker.MagicMock()
    mock_node.id = "node-with-bad-auto-creds"
    mock_node.credentials_optional = False
    mock_node.input_default = {
        "spreadsheet": {
            "_credentials_id": "other-users-cred-id",
            "id": "file-123",
            "name": "test.xlsx",
        }
    }

    mock_block = mocker.MagicMock()
    mock_block.input_schema.get_credentials_fields.return_value = {}
    mock_block.input_schema.get_auto_credentials_fields.return_value = {
        "credentials": {
            "field_name": "spreadsheet",
            "config": {"provider": "google", "type": "oauth2"},
        }
    }
    mock_node.block = mock_block

    mock_graph = mocker.MagicMock()
    mock_graph.nodes = [mock_node]

    # Mock the credentials store to return None (cred not found for this user)
    mock_store = mocker.MagicMock()
    mock_store.get_creds_by_id = mocker.AsyncMock(return_value=None)
    mocker.patch(
        "backend.executor.utils.get_integration_credentials_store",
        return_value=mock_store,
    )

    errors, nodes_to_skip = await _validate_node_input_credentials(
        graph=mock_graph,
        user_id="different-user",
        nodes_input_masks=None,
    )

    assert mock_node.id in errors
    assert "spreadsheet" in errors[mock_node.id]
    assert "not available" in errors[mock_node.id]["spreadsheet"].lower()


@pytest.mark.asyncio
async def test_validate_node_input_credentials_both_regular_and_auto(
    mocker: MockerFixture,
):
    """
    [SECRT-1772] A node that has BOTH regular credentials AND auto_credentials
    should have both validated.
    """
    from backend.executor.utils import _validate_node_input_credentials

    mock_node = mocker.MagicMock()
    mock_node.id = "node-with-both-creds"
    mock_node.credentials_optional = False
    mock_node.input_default = {
        "credentials": {
            "id": "regular-cred-id",
            "provider": "github",
            "type": "api_key",
        },
        "spreadsheet": {
            "_credentials_id": "auto-cred-id",
            "id": "file-123",
            "name": "test.xlsx",
        },
    }

    mock_credentials_field_type = mocker.MagicMock()
    mock_credentials_meta = mocker.MagicMock()
    mock_credentials_meta.id = "regular-cred-id"
    mock_credentials_meta.provider = "github"
    mock_credentials_meta.type = "api_key"
    mock_credentials_field_type.model_validate.return_value = mock_credentials_meta

    mock_block = mocker.MagicMock()
    # Regular credentials field
    mock_block.input_schema.get_credentials_fields.return_value = {
        "credentials": mock_credentials_field_type,
    }
    # Auto-credentials field
    mock_block.input_schema.get_auto_credentials_fields.return_value = {
        "auto_credentials": {
            "field_name": "spreadsheet",
            "config": {"provider": "google", "type": "oauth2"},
        }
    }
    mock_node.block = mock_block

    mock_graph = mocker.MagicMock()
    mock_graph.nodes = [mock_node]

    # Mock the credentials store to return valid credentials for both
    mock_store = mocker.MagicMock()
    mock_regular_creds = mocker.MagicMock()
    mock_regular_creds.id = "regular-cred-id"
    mock_regular_creds.provider = "github"
    mock_regular_creds.type = "api_key"

    mock_auto_creds = mocker.MagicMock()
    mock_auto_creds.id = "auto-cred-id"

    def get_creds_side_effect(user_id, cred_id):
        if cred_id == "regular-cred-id":
            return mock_regular_creds
        elif cred_id == "auto-cred-id":
            return mock_auto_creds
        return None

    mock_store.get_creds_by_id = mocker.AsyncMock(side_effect=get_creds_side_effect)
    mocker.patch(
        "backend.executor.utils.get_integration_credentials_store",
        return_value=mock_store,
    )

    errors, nodes_to_skip = await _validate_node_input_credentials(
        graph=mock_graph,
        user_id="test-user",
        nodes_input_masks=None,
    )

    # Both should validate successfully - no errors
    assert mock_node.id not in errors
    assert mock_node.id not in nodes_to_skip


@pytest.mark.asyncio
async def test_validate_node_input_credentials_auto_creds_skipped_when_none(
    mocker: MockerFixture,
):
    """
    When a node has auto_credentials but the field value has _credentials_id=None
    (e.g., from upstream connection), validation should skip it without error.
    """
    from backend.executor.utils import _validate_node_input_credentials

    mock_node = mocker.MagicMock()
    mock_node.id = "node-with-chained-auto-creds"
    mock_node.credentials_optional = False
    mock_node.input_default = {
        "spreadsheet": {
            "_credentials_id": None,
            "id": "file-123",
            "name": "test.xlsx",
        }
    }

    mock_block = mocker.MagicMock()
    mock_block.input_schema.get_credentials_fields.return_value = {}
    mock_block.input_schema.get_auto_credentials_fields.return_value = {
        "credentials": {
            "field_name": "spreadsheet",
            "config": {"provider": "google", "type": "oauth2"},
        }
    }
    mock_node.block = mock_block

    mock_graph = mocker.MagicMock()
    mock_graph.nodes = [mock_node]

    errors, nodes_to_skip = await _validate_node_input_credentials(
        graph=mock_graph,
        user_id="test-user",
        nodes_input_masks=None,
    )

    # No error - chained data with None cred_id is valid
    assert mock_node.id not in errors


# ============================================================================
# Tests for CredentialsFieldInfo auto_credential tag (Fix 4: Path 4)
# ============================================================================


def test_credentials_field_info_auto_credential_tag():
    """
    [Path 4] CredentialsFieldInfo should support is_auto_credential and
    input_field_name fields for distinguishing auto from regular credentials.
    """
    from backend.data.model import CredentialsFieldInfo

    # Regular credential should have is_auto_credential=False by default
    regular = CredentialsFieldInfo.model_validate(
        {
            "credentials_provider": ["github"],
            "credentials_types": ["api_key"],
        },
        by_alias=True,
    )
    assert regular.is_auto_credential is False
    assert regular.input_field_name is None

    # Auto credential should have is_auto_credential=True
    auto = CredentialsFieldInfo.model_validate(
        {
            "credentials_provider": ["google"],
            "credentials_types": ["oauth2"],
            "is_auto_credential": True,
            "input_field_name": "spreadsheet",
        },
        by_alias=True,
    )
    assert auto.is_auto_credential is True
    assert auto.input_field_name == "spreadsheet"


def test_make_node_credentials_input_map_excludes_auto_creds(
    mocker: MockerFixture,
):
    """
    [Path 4] make_node_credentials_input_map should only include regular credentials,
    not auto_credentials (which are resolved at execution time).
    """
    from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
    from backend.executor.utils import make_node_credentials_input_map
    from backend.integrations.providers import ProviderName

    # Create a mock graph with aggregate_credentials_inputs that returns
    # both regular and auto credentials
    mock_graph = mocker.MagicMock()

    regular_field_info = CredentialsFieldInfo.model_validate(
        {
            "credentials_provider": ["github"],
            "credentials_types": ["api_key"],
            "is_auto_credential": False,
        },
        by_alias=True,
    )

    # Mock regular_credentials_inputs property (auto_credentials are excluded)
    mock_graph.regular_credentials_inputs = {
        "github_creds": (regular_field_info, {("node-1", "credentials")}, True),
    }

    graph_credentials_input = {
        "github_creds": CredentialsMetaInput(
            id="cred-123",
            provider=ProviderName("github"),
            type="api_key",
        ),
    }

    result = make_node_credentials_input_map(mock_graph, graph_credentials_input)

    # Regular credentials should be mapped
    assert "node-1" in result
    assert "credentials" in result["node-1"]

    # Auto credentials should NOT appear in the result
    # (they would have been mapped to the kwarg_name "credentials" not "spreadsheet")
    for node_id, fields in result.items():
        for field_name, value in fields.items():
            # Verify no auto-credential phantom entries
            if isinstance(value, dict):
                assert "_credentials_id" not in value
@@ -342,14 +342,6 @@ async def store_media_file(
        if not target_path.is_file():
            raise ValueError(f"Local file does not exist: {target_path}")

-        # Virus scan the local file before any further processing
-        local_content = target_path.read_bytes()
-        if len(local_content) > MAX_FILE_SIZE_BYTES:
-            raise ValueError(
-                f"File too large: {len(local_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
-            )
-        await scan_content_safe(local_content, filename=sanitized_file)

        # Return based on requested format
        if return_format == "for_local_processing":
            # Use when processing files locally with tools like ffmpeg, MoviePy, PIL
@@ -247,100 +247,3 @@ class TestFileCloudIntegration:
|
|||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
return_format="for_local_processing",
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_media_file_local_path_scanned(self):
|
|
||||||
"""Test that local file paths are scanned for viruses."""
|
|
||||||
graph_exec_id = "test-exec-123"
|
|
||||||
local_file = "test_video.mp4"
|
|
||||||
file_content = b"fake video content"
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.util.file.get_cloud_storage_handler"
|
|
||||||
) as mock_handler_getter, patch(
|
|
||||||
"backend.util.file.scan_content_safe"
|
|
||||||
) as mock_scan, patch(
|
|
||||||
"backend.util.file.Path"
|
|
||||||
) as mock_path_class:
|
|
||||||
|
|
||||||
# Mock cloud storage handler - not a cloud path
|
|
||||||
mock_handler = MagicMock()
|
|
||||||
mock_handler.is_cloud_path.return_value = False
|
|
||||||
mock_handler_getter.return_value = mock_handler
|
|
||||||
|
|
||||||
# Mock virus scanner
|
|
||||||
mock_scan.return_value = None
|
|
||||||
|
|
||||||
# Mock file system operations
|
|
||||||
mock_base_path = MagicMock()
|
|
||||||
mock_target_path = MagicMock()
|
|
||||||
mock_resolved_path = MagicMock()
|
|
||||||
|
|
||||||
mock_path_class.return_value = mock_base_path
|
|
||||||
mock_base_path.mkdir = MagicMock()
|
|
||||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
|
||||||
mock_target_path.resolve.return_value = mock_resolved_path
|
|
||||||
mock_resolved_path.is_relative_to.return_value = True
|
|
||||||
mock_resolved_path.is_file.return_value = True
|
|
||||||
mock_resolved_path.read_bytes.return_value = file_content
|
|
||||||
mock_resolved_path.relative_to.return_value = Path(local_file)
|
|
||||||
mock_resolved_path.name = local_file
|
|
||||||
|
|
||||||
result = await store_media_file(
|
|
||||||
file=MediaFileType(local_file),
|
|
||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify virus scan was called for local file
|
|
||||||
mock_scan.assert_called_once_with(file_content, filename=local_file)
|
|
||||||
|
|
||||||
# Result should be the relative path
|
|
||||||
assert str(result) == local_file
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_media_file_local_path_virus_detected(self):
|
|
||||||
"""Test that infected local files raise VirusDetectedError."""
|
|
||||||
from backend.api.features.store.exceptions import VirusDetectedError
|
|
||||||
|
|
||||||
graph_exec_id = "test-exec-123"
|
|
||||||
local_file = "infected.exe"
|
|
||||||
file_content = b"malicious content"
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.util.file.get_cloud_storage_handler"
|
|
||||||
) as mock_handler_getter, patch(
|
|
||||||
"backend.util.file.scan_content_safe"
|
|
||||||
) as mock_scan, patch(
|
|
||||||
"backend.util.file.Path"
|
|
||||||
) as mock_path_class:
|
|
||||||
|
|
||||||
# Mock cloud storage handler - not a cloud path
|
|
||||||
mock_handler = MagicMock()
|
|
||||||
mock_handler.is_cloud_path.return_value = False
|
|
||||||
mock_handler_getter.return_value = mock_handler
|
|
||||||
|
|
||||||
# Mock virus scanner to detect virus
|
|
||||||
mock_scan.side_effect = VirusDetectedError(
|
|
||||||
"EICAR-Test-File", "File rejected due to virus detection"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock file system operations
|
|
||||||
mock_base_path = MagicMock()
|
|
||||||
mock_target_path = MagicMock()
|
|
||||||
mock_resolved_path = MagicMock()
|
|
||||||
|
|
||||||
mock_path_class.return_value = mock_base_path
|
|
||||||
mock_base_path.mkdir = MagicMock()
|
|
||||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
|
||||||
mock_target_path.resolve.return_value = mock_resolved_path
|
|
||||||
mock_resolved_path.is_relative_to.return_value = True
|
|
||||||
mock_resolved_path.is_file.return_value = True
|
|
||||||
mock_resolved_path.read_bytes.return_value = file_content
|
|
||||||
|
|
||||||
with pytest.raises(VirusDetectedError):
|
|
||||||
await store_media_file(
|
|
||||||
file=MediaFileType(local_file),
|
|
||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|||||||
7101  autogpt_platform/backend/poetry.lock  (generated)
File diff suppressed because it is too large
@@ -11,17 +11,17 @@ packages = [{ include = "backend", format = "sdist" }]
python = ">=3.10,<3.14"
aio-pika = "^9.5.5"
aiohttp = "^3.10.0"
-aiodns = "^4.0.0"
+aiodns = "^3.5.0"
-anthropic = "^0.79.0"
+anthropic = "^0.59.0"
apscheduler = "^3.11.1"
autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" }
click = "^8.2.0"
-cryptography = "^46.0"
+cryptography = "^45.0"
discord-py = "^2.5.2"
-e2b-code-interpreter = "^2.4.1"
+e2b-code-interpreter = "^1.5.2"
elevenlabs = "^1.50.0"
-fastapi = "^0.128.6"
+fastapi = "^0.116.1"
feedparser = "^6.0.11"
flake8 = "^7.3.0"
google-api-python-client = "^2.177.0"

@@ -29,16 +29,16 @@ google-auth-oauthlib = "^1.2.2"
google-cloud-storage = "^3.2.0"
googlemaps = "^4.10.0"
gravitasml = "^0.1.4"
-groq = "^1.0.0"
+groq = "^0.30.0"
html2text = "^2024.2.26"
jinja2 = "^3.1.6"
jsonref = "^1.1.0"
jsonschema = "^4.25.0"
-langfuse = "^3.14.1"
+langfuse = "^3.11.0"
-launchdarkly-server-sdk = "^9.14.1"
+launchdarkly-server-sdk = "^9.12.0"
mem0ai = "^0.1.115"
moviepy = "^2.1.2"
-ollama = "^0.6.1"
+ollama = "^0.5.1"
openai = "^1.97.1"
orjson = "^3.10.0"
pika = "^1.3.2"

@@ -48,36 +48,36 @@ postmarker = "^1.0"
praw = "~7.8.1"
prisma = "^0.15.0"
rank-bm25 = "^0.2.2"
-prometheus-client = "^0.24.1"
+prometheus-client = "^0.22.1"
prometheus-fastapi-instrumentator = "^7.0.0"
psutil = "^7.0.0"
psycopg2-binary = "^2.9.10"
-pydantic = { extras = ["email"], version = "^2.12.5" }
+pydantic = { extras = ["email"], version = "^2.11.7" }
-pydantic-settings = "^2.12.0"
+pydantic-settings = "^2.10.1"
pytest = "^8.4.1"
pytest-asyncio = "^1.1.0"
python-dotenv = "^1.1.1"
-python-multipart = "^0.0.22"
+python-multipart = "^0.0.20"
-redis = "^7.1.1"
+redis = "^6.2.0"
regex = "^2025.9.18"
replicate = "^1.0.6"
sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlalchemy"], version = "^2.44.0"}
sqlalchemy = "^2.0.40"
strenum = "^0.4.9"
stripe = "^11.5.0"
-supabase = "2.28.0"
+supabase = "2.17.0"
-tenacity = "^9.1.4"
+tenacity = "^9.1.2"
-todoist-api-python = "^3.2.1"
+todoist-api-python = "^2.1.7"
tweepy = "^4.16.0"
-uvicorn = { extras = ["standard"], version = "^0.40.0" }
+uvicorn = { extras = ["standard"], version = "^0.35.0" }
websockets = "^15.0"
youtube-transcript-api = "^1.2.1"
-yt-dlp = "2026.2.4"
+yt-dlp = "2025.12.08"
zerobouncesdk = "^1.1.2"
# NOTE: please insert new dependencies in their alphabetical location
pytest-snapshot = "^0.9.0"
-aiofiles = "^25.1.0"
+aiofiles = "^24.1.0"
-tiktoken = "^0.12.0"
+tiktoken = "^0.9.0"
aioclamd = "^1.0.0"
setuptools = "^80.9.0"
gcloud-aio-storage = "^9.5.0"

@@ -85,7 +85,7 @@ pandas = "^2.3.1"
firecrawl-py = "^4.3.6"
exa-py = "^1.14.20"
croniter = "^6.0.0"
-stagehand = "^3.5.0"
+stagehand = "^0.5.1"
gravitas-md2gdocs = "^0.1.0"
posthog = "^7.6.0"

@@ -94,14 +94,14 @@ aiohappyeyeballs = "^2.6.1"
black = "^24.10.0"
faker = "^38.2.0"
httpx = "^0.28.1"
-isort = "^7.0.0"
+isort = "^5.13.2"
-poethepoet = "^0.41.0"
+poethepoet = "^0.37.0"
pre-commit = "^4.4.0"
pyright = "^1.1.407"
pytest-mock = "^3.15.1"
-pytest-watcher = "^0.6.3"
+pytest-watcher = "^0.4.2"
requests = "^2.32.5"
-ruff = "^0.15.0"
+ruff = "^0.14.5"
# NOTE: please insert new dependencies in their alphabetical location

[build-system]
@@ -25,12 +25,8 @@ RUN if [ -f .env.production ]; then \
    cp .env.default .env; \
    fi
RUN pnpm run generate:api
-# Disable source-map generation in Docker builds to halve webpack memory usage.
-# Source maps are only useful when SENTRY_AUTH_TOKEN is set (Vercel deploys);
-# the Docker image never uploads them, so generating them just wastes RAM.
-ENV NEXT_PUBLIC_SOURCEMAPS="false"
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
-RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=8192" pnpm build; else NODE_OPTIONS="--max-old-space-size=8192" pnpm build; fi
+RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=4096" pnpm build; else NODE_OPTIONS="--max-old-space-size=4096" pnpm build; fi

# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
FROM node:21-alpine AS prod
@@ -1,12 +1,8 @@
import { withSentryConfig } from "@sentry/nextjs";

-// Allow Docker builds to skip source-map generation (halves memory usage).
-// Defaults to true so Vercel/local builds are unaffected.
-const enableSourceMaps = process.env.NEXT_PUBLIC_SOURCEMAPS !== "false";

/** @type {import('next').NextConfig} */
const nextConfig = {
-  productionBrowserSourceMaps: enableSourceMaps,
+  productionBrowserSourceMaps: true,
  // Externalize OpenTelemetry packages to fix Turbopack HMR issues
  serverExternalPackages: [
    "@opentelemetry/instrumentation",

@@ -18,37 +14,9 @@ const nextConfig = {
    serverActions: {
      bodySizeLimit: "256mb",
    },
+    // Increase body size limit for API routes (file uploads) - 256MB to match backend limit
+    proxyClientMaxBodySize: "256mb",
    middlewareClientMaxBodySize: "256mb",
-    // Limit parallel webpack workers to reduce peak memory during builds.
-    cpus: 2,
-  },
-  // Work around cssnano "Invalid array length" bug in Next.js's bundled
-  // cssnano-simple comment parser when processing very large CSS chunks.
-  // CSS is still bundled correctly; gzip handles most of the size savings anyway.
-  webpack: (config, { dev }) => {
-    if (!dev) {
-      // Next.js adds CssMinimizerPlugin internally (after user config), so we
-      // can't filter it from config.plugins. Instead, intercept the webpack
-      // compilation hooks and replace the buggy plugin's tap with a no-op.
-      config.plugins.push({
-        apply(compiler) {
-          compiler.hooks.compilation.tap(
-            "DisableCssMinimizer",
-            (compilation) => {
-              compilation.hooks.processAssets.intercept({
-                register: (tap) => {
-                  if (tap.name === "CssMinimizerPlugin") {
-                    return { ...tap, fn: async () => {} };
-                  }
-                  return tap;
-                },
-              });
-            },
-          );
-        },
-      });
-    }
-    return config;
  },
  images: {
    domains: [

@@ -86,16 +54,9 @@ const nextConfig = {
  transpilePackages: ["geist"],
};

-// Only run the Sentry webpack plugin when we can actually upload source maps
-// (i.e. on Vercel with SENTRY_AUTH_TOKEN set). The Sentry *runtime* SDK
-// (imported in app code) still captures errors without the plugin.
-// Skipping the plugin saves ~1 GB of peak memory during `next build`.
-const skipSentryPlugin =
-  process.env.NODE_ENV !== "production" ||
-  !enableSourceMaps ||
-  !process.env.SENTRY_AUTH_TOKEN;
+const isDevelopmentBuild = process.env.NODE_ENV !== "production";

-export default skipSentryPlugin
+export default isDevelopmentBuild
  ? nextConfig
  : withSentryConfig(nextConfig, {
      // For all available options, see:

@@ -135,7 +96,7 @@ export default skipSentryPlugin

      // This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
      sourcemaps: {
-        disable: !enableSourceMaps,
+        disable: false,
        assets: [".next/**/*.js", ".next/**/*.js.map"],
        ignore: ["**/node_modules/**"],
        deleteSourcemapsAfterUpload: false, // Source is public anyway :)
@@ -7,7 +7,7 @@
  },
  "scripts": {
    "dev": "pnpm run generate:api:force && next dev --turbo",
-    "build": "cross-env NODE_OPTIONS=--max-old-space-size=16384 next build",
+    "build": "next build",
    "start": "next start",
    "start:standalone": "cd .next/standalone && node server.js",
    "lint": "next lint && prettier --check .",

@@ -30,7 +30,6 @@
    "defaults"
  ],
  "dependencies": {
-    "@ai-sdk/react": "3.0.61",
    "@faker-js/faker": "10.0.0",
    "@hookform/resolvers": "5.2.2",
    "@next/third-parties": "15.4.6",

@@ -61,10 +60,6 @@
    "@rjsf/utils": "6.1.2",
    "@rjsf/validator-ajv8": "6.1.2",
    "@sentry/nextjs": "10.27.0",
-    "@streamdown/cjk": "1.0.1",
-    "@streamdown/code": "1.0.1",
-    "@streamdown/math": "1.0.1",
-    "@streamdown/mermaid": "1.0.1",
    "@supabase/ssr": "0.7.0",
    "@supabase/supabase-js": "2.78.0",
    "@tanstack/react-query": "5.90.6",

@@ -73,7 +68,6 @@
    "@vercel/analytics": "1.5.0",
    "@vercel/speed-insights": "1.2.0",
    "@xyflow/react": "12.9.2",
-    "ai": "6.0.59",
    "boring-avatars": "1.11.2",
    "class-variance-authority": "0.7.1",
    "clsx": "2.1.1",

@@ -93,6 +87,7 @@
    "launchdarkly-react-client-sdk": "3.9.0",
    "lodash": "4.17.21",
    "lucide-react": "0.552.0",
+    "moment": "2.30.1",
    "next": "15.4.10",
    "next-themes": "0.4.6",
    "nuqs": "2.7.2",

@@ -107,7 +102,7 @@
    "react-markdown": "9.0.3",
    "react-modal": "3.16.3",
    "react-shepherd": "6.1.9",
-    "react-window": "2.2.0",
+    "react-window": "1.8.11",
    "recharts": "3.3.0",
    "rehype-autolink-headings": "7.1.0",
    "rehype-highlight": "7.0.2",

@@ -117,11 +112,9 @@
    "remark-math": "6.0.0",
    "shepherd.js": "14.5.1",
    "sonner": "2.0.7",
-    "streamdown": "2.1.0",
    "tailwind-merge": "2.6.0",
    "tailwind-scrollbar": "3.1.0",
    "tailwindcss-animate": "1.0.7",
-    "use-stick-to-bottom": "1.1.2",
    "uuid": "11.1.0",
    "vaul": "1.1.2",
    "zod": "3.25.76",

@@ -147,7 +140,7 @@
    "@types/react": "18.3.17",
    "@types/react-dom": "18.3.5",
    "@types/react-modal": "3.16.3",
-    "@types/react-window": "2.0.0",
+    "@types/react-window": "1.8.8",
    "@vitejs/plugin-react": "5.1.2",
    "axe-playwright": "2.2.2",
    "chromatic": "13.3.3",

@@ -179,8 +172,7 @@
  },
  "pnpm": {
    "overrides": {
-      "@opentelemetry/instrumentation": "0.209.0",
-      "lodash-es": "4.17.23"
+      "@opentelemetry/instrumentation": "0.209.0"
    }
  },
  "packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
1218  autogpt_platform/frontend/pnpm-lock.yaml  (generated)
File diff suppressed because it is too large
@@ -1,4 +1,4 @@
-import debounce from "lodash/debounce";
+import { debounce } from "lodash";
import { useCallback, useEffect, useRef, useState } from "react";
import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
import { getQueryClient } from "@/lib/react-query/queryClient";
@@ -70,10 +70,10 @@ export const HorizontalScroll: React.FC<HorizontalScrollAreaProps> = ({
        {children}
      </div>
      {canScrollLeft && (
-        <div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-background via-background/80 to-background/0" />
+        <div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-white via-white/80 to-white/0" />
      )}
      {canScrollRight && (
-        <div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-background via-background/80 to-background/0" />
+        <div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-white via-white/80 to-white/0" />
      )}
      {canScrollLeft && (
        <button
@@ -1,80 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { SidebarProvider } from "@/components/ui/sidebar";
|
|
||||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
|
||||||
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
|
||||||
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
|
||||||
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
|
||||||
import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
|
|
||||||
import { useCopilotPage } from "./useCopilotPage";
|
|
||||||
|
|
||||||
export function CopilotPage() {
|
|
||||||
const {
|
|
||||||
sessionId,
|
|
||||||
messages,
|
|
||||||
status,
|
|
||||||
error,
|
|
||||||
stop,
|
|
||||||
createSession,
|
|
||||||
onSend,
|
|
||||||
isLoadingSession,
|
|
||||||
isCreatingSession,
|
|
||||||
isUserLoading,
|
|
||||||
isLoggedIn,
|
|
||||||
// Mobile drawer
|
|
||||||
isMobile,
|
|
||||||
isDrawerOpen,
|
|
||||||
sessions,
|
|
||||||
isLoadingSessions,
|
|
||||||
handleOpenDrawer,
|
|
||||||
handleCloseDrawer,
|
|
||||||
handleDrawerOpenChange,
|
|
||||||
handleSelectSession,
|
|
||||||
handleNewChat,
|
|
||||||
} = useCopilotPage();
|
|
||||||
|
|
||||||
if (isUserLoading || !isLoggedIn) {
|
|
||||||
return (
|
|
||||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-[#f8f8f9]">
|
|
||||||
<ScaleLoader className="text-neutral-400" />
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<SidebarProvider
|
|
||||||
defaultOpen={true}
|
|
||||||
className="h-[calc(100vh-72px)] min-h-0"
|
|
||||||
>
|
|
||||||
{!isMobile && <ChatSidebar />}
|
|
||||||
<div className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0">
|
|
||||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
|
||||||
<div className="flex-1 overflow-hidden">
|
|
||||||
<ChatContainer
|
|
||||||
messages={messages}
|
|
||||||
status={status}
|
|
||||||
error={error}
|
|
||||||
sessionId={sessionId}
|
|
||||||
isLoadingSession={isLoadingSession}
|
|
||||||
isCreatingSession={isCreatingSession}
|
|
||||||
onCreateSession={createSession}
|
|
||||||
onSend={onSend}
|
|
||||||
onStop={stop}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{isMobile && (
|
|
||||||
<MobileDrawer
|
|
||||||
isOpen={isDrawerOpen}
|
|
||||||
sessions={sessions}
|
|
||||||
currentSessionId={sessionId}
|
|
||||||
isLoading={isLoadingSessions}
|
|
||||||
onSelectSession={handleSelectSession}
|
|
||||||
onNewChat={handleNewChat}
|
|
||||||
onClose={handleCloseDrawer}
|
|
||||||
onOpenChange={handleDrawerOpenChange}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</SidebarProvider>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
"use client";
|
|
||||||
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
|
||||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
|
||||||
import { LayoutGroup, motion } from "framer-motion";
|
|
||||||
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
|
|
||||||
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
|
|
||||||
import { EmptySession } from "../EmptySession/EmptySession";
|
|
||||||
|
|
||||||
export interface ChatContainerProps {
|
|
||||||
messages: UIMessage<unknown, UIDataTypes, UITools>[];
|
|
||||||
status: string;
|
|
||||||
error: Error | undefined;
|
|
||||||
sessionId: string | null;
|
|
||||||
isLoadingSession: boolean;
|
|
||||||
isCreatingSession: boolean;
|
|
||||||
onCreateSession: () => void | Promise<string>;
|
|
||||||
onSend: (message: string) => void | Promise<void>;
|
|
||||||
onStop: () => void;
|
|
||||||
}
|
|
||||||
export const ChatContainer = ({
|
|
||||||
messages,
|
|
||||||
status,
|
|
||||||
error,
|
|
||||||
sessionId,
|
|
||||||
isLoadingSession,
|
|
||||||
isCreatingSession,
|
|
||||||
onCreateSession,
|
|
||||||
onSend,
|
|
||||||
onStop,
|
|
||||||
}: ChatContainerProps) => {
|
|
||||||
const inputLayoutId = "copilot-2-chat-input";
|
|
||||||
|
|
||||||
return (
|
|
||||||
<CopilotChatActionsProvider onSend={onSend}>
|
|
||||||
<LayoutGroup id="copilot-2-chat-layout">
|
|
||||||
<div className="flex h-full min-h-0 w-full flex-col bg-[#f8f8f9] px-2 lg:px-0">
|
|
||||||
{sessionId ? (
|
|
||||||
<div className="mx-auto flex h-full min-h-0 w-full max-w-3xl flex-col">
|
|
||||||
<ChatMessagesContainer
|
|
||||||
messages={messages}
|
|
||||||
status={status}
|
|
||||||
error={error}
|
|
||||||
isLoading={isLoadingSession}
|
|
||||||
/>
|
|
||||||
<motion.div
|
|
||||||
initial={{ opacity: 0 }}
|
|
||||||
animate={{ opacity: 1 }}
|
|
||||||
transition={{ duration: 0.3 }}
|
|
||||||
className="relative px-3 pb-2 pt-2"
|
|
||||||
>
|
|
||||||
<div className="pointer-events-none absolute left-0 right-0 top-[-18px] z-10 h-6 bg-gradient-to-b from-transparent to-[#f8f8f9]" />
|
|
||||||
<ChatInput
|
|
||||||
inputId="chat-input-session"
|
|
||||||
onSend={onSend}
|
|
||||||
disabled={status === "streaming"}
|
|
||||||
isStreaming={status === "streaming"}
|
|
||||||
onStop={onStop}
|
|
||||||
placeholder="What else can I help with?"
|
|
||||||
/>
|
|
||||||
</motion.div>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<EmptySession
|
|
||||||
inputLayoutId={inputLayoutId}
|
|
||||||
isCreatingSession={isCreatingSession}
|
|
||||||
onCreateSession={onCreateSession}
|
|
||||||
onSend={onSend}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</LayoutGroup>
|
|
||||||
</CopilotChatActionsProvider>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
@@ -1,274 +0,0 @@
|
|||||||
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
|
|
||||||
import {
|
|
||||||
Conversation,
|
|
||||||
ConversationContent,
|
|
||||||
ConversationScrollButton,
|
|
||||||
} from "@/components/ai-elements/conversation";
|
|
||||||
import {
|
|
||||||
Message,
|
|
||||||
MessageContent,
|
|
||||||
MessageResponse,
|
|
||||||
} from "@/components/ai-elements/message";
|
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
|
||||||
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
|
||||||
import { useEffect, useState } from "react";
|
|
||||||
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
|
|
||||||
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
|
|
||||||
import { FindAgentsTool } from "../../tools/FindAgents/FindAgents";
|
|
||||||
import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
|
|
||||||
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
|
|
||||||
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
|
|
||||||
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
|
|
||||||
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// Workspace media support
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Resolve workspace:// URLs in markdown text to proxy download URLs.
|
|
||||||
* Detects MIME type from the hash fragment (e.g. workspace://id#video/mp4)
|
|
||||||
* and prefixes the alt text with "video:" so the custom img component can
|
|
||||||
* render a <video> element instead.
|
|
||||||
*/
|
|
||||||
function resolveWorkspaceUrls(text: string): string {
|
|
||||||
return text.replace(
|
|
||||||
/!\[([^\]]*)\]\(workspace:\/\/([^)#\s]+)(?:#([^)\s]*))?\)/g,
|
|
||||||
(_match, alt: string, fileId: string, mimeHint?: string) => {
|
|
||||||
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
|
|
||||||
const url = `/api/proxy${apiPath}`;
|
|
||||||
if (mimeHint?.startsWith("video/")) {
|
|
||||||
return ``;
|
|
||||||
}
|
|
||||||
return ``;
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Custom img component for Streamdown that renders <video> elements
|
|
||||||
* for workspace video files (detected via "video:" alt-text prefix).
|
|
||||||
* Falls back to <video> when an <img> fails to load for workspace files.
|
|
||||||
*/
|
|
||||||
function WorkspaceMediaImage(props: React.JSX.IntrinsicElements["img"]) {
|
|
||||||
const { src, alt, ...rest } = props;
|
|
||||||
const [imgFailed, setImgFailed] = useState(false);
|
|
||||||
const isWorkspace = src?.includes("/workspace/files/") ?? false;
|
|
||||||
|
|
||||||
if (!src) return null;
|
|
||||||
|
|
||||||
if (alt?.startsWith("video:") || (imgFailed && isWorkspace)) {
|
|
||||||
return (
|
|
||||||
<span className="my-2 inline-block">
|
|
||||||
<video
|
|
||||||
controls
|
|
||||||
className="h-auto max-w-full rounded-md border border-zinc-200"
|
|
||||||
preload="metadata"
|
|
||||||
>
|
|
||||||
<source src={src} />
|
|
||||||
Your browser does not support the video tag.
|
|
||||||
</video>
|
|
||||||
</span>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
// eslint-disable-next-line @next/next/no-img-element
|
|
||||||
<img
|
|
||||||
src={src}
|
|
||||||
alt={alt || "Image"}
|
|
||||||
className="h-auto max-w-full rounded-md border border-zinc-200"
|
|
||||||
loading="lazy"
|
|
||||||
onError={() => {
|
|
||||||
if (isWorkspace) setImgFailed(true);
|
|
||||||
}}
|
|
||||||
{...rest}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Stable components override for Streamdown (avoids re-creating on every render). */
|
|
||||||
const STREAMDOWN_COMPONENTS = { img: WorkspaceMediaImage };
|
|
||||||
|
|
||||||
const THINKING_PHRASES = [
|
|
||||||
"Thinking...",
|
|
||||||
"Considering this...",
|
|
||||||
"Working through this...",
|
|
||||||
"Analyzing your request...",
|
|
||||||
"Reasoning...",
|
|
||||||
"Looking into it...",
|
|
||||||
"Processing your request...",
|
|
||||||
"Mulling this over...",
|
|
||||||
"Piecing it together...",
|
|
||||||
"On it...",
|
|
||||||
];
|
|
||||||
|
|
||||||
function getRandomPhrase() {
|
|
||||||
return THINKING_PHRASES[Math.floor(Math.random() * THINKING_PHRASES.length)];
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ChatMessagesContainerProps {
|
|
||||||
messages: UIMessage<unknown, UIDataTypes, UITools>[];
|
|
||||||
status: string;
|
|
||||||
error: Error | undefined;
|
|
||||||
isLoading: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
export const ChatMessagesContainer = ({
|
|
||||||
messages,
|
|
||||||
status,
|
|
||||||
error,
|
|
||||||
isLoading,
|
|
||||||
}: ChatMessagesContainerProps) => {
|
|
||||||
const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (status === "submitted") {
|
|
||||||
setThinkingPhrase(getRandomPhrase());
|
|
||||||
}
|
|
||||||
}, [status]);
|
|
||||||
|
|
||||||
const lastMessage = messages[messages.length - 1];
|
|
||||||
const lastAssistantHasVisibleContent =
|
|
||||||
lastMessage?.role === "assistant" &&
|
|
||||||
lastMessage.parts.some(
|
|
||||||
(p) =>
|
|
||||||
(p.type === "text" && p.text.trim().length > 0) ||
|
|
||||||
p.type.startsWith("tool-"),
|
|
||||||
);
|
|
||||||
|
|
||||||
const showThinking =
|
|
||||||
status === "submitted" ||
|
|
||||||
(status === "streaming" && !lastAssistantHasVisibleContent);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Conversation className="min-h-0 flex-1">
|
|
||||||
<ConversationContent className="flex min-h-screen flex-1 flex-col gap-6 px-3 py-6">
|
|
||||||
{isLoading && messages.length === 0 && (
|
|
||||||
<div className="flex min-h-full flex-1 items-center justify-center">
|
|
||||||
<LoadingSpinner className="text-neutral-600" />
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{messages.map((message, messageIndex) => {
|
|
||||||
const isLastAssistant =
|
|
||||||
messageIndex === messages.length - 1 &&
|
|
||||||
message.role === "assistant";
|
|
||||||
const messageHasVisibleContent = message.parts.some(
|
|
||||||
(p) =>
|
|
||||||
(p.type === "text" && p.text.trim().length > 0) ||
|
|
||||||
p.type.startsWith("tool-"),
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Message from={message.role} key={message.id}>
|
|
||||||
<MessageContent
|
|
||||||
className={
|
|
||||||
"text-[1rem] leading-relaxed " +
|
|
||||||
"group-[.is-user]:rounded-xl group-[.is-user]:bg-purple-100 group-[.is-user]:px-3 group-[.is-user]:py-2.5 group-[.is-user]:text-slate-900 group-[.is-user]:[border-bottom-right-radius:0] " +
|
|
||||||
"group-[.is-assistant]:bg-transparent group-[.is-assistant]:text-slate-900"
|
|
||||||
}
|
|
||||||
>
|
|
||||||
{message.parts.map((part, i) => {
|
|
||||||
switch (part.type) {
|
|
||||||
case "text":
|
|
||||||
return (
|
|
||||||
<MessageResponse
|
|
||||||
key={`${message.id}-${i}`}
|
|
||||||
components={STREAMDOWN_COMPONENTS}
|
|
||||||
>
|
|
||||||
{resolveWorkspaceUrls(part.text)}
|
|
||||||
</MessageResponse>
|
|
||||||
);
|
|
||||||
case "tool-find_block":
|
|
||||||
return (
|
|
||||||
<FindBlocksTool
|
|
||||||
key={`${message.id}-${i}`}
|
|
||||||
part={part as ToolUIPart}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
case "tool-find_agent":
|
|
||||||
case "tool-find_library_agent":
|
|
||||||
return (
|
|
||||||
<FindAgentsTool
|
|
||||||
key={`${message.id}-${i}`}
|
|
||||||
part={part as ToolUIPart}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
case "tool-search_docs":
|
|
||||||
case "tool-get_doc_page":
|
|
||||||
return (
|
|
||||||
<SearchDocsTool
|
|
||||||
key={`${message.id}-${i}`}
|
|
||||||
part={part as ToolUIPart}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
case "tool-run_block":
|
|
||||||
return (
|
|
||||||
<RunBlockTool
|
|
||||||
key={`${message.id}-${i}`}
|
|
||||||
part={part as ToolUIPart}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
case "tool-run_agent":
|
|
||||||
case "tool-schedule_agent":
|
|
||||||
return (
|
|
||||||
<RunAgentTool
|
|
||||||
key={`${message.id}-${i}`}
|
|
||||||
part={part as ToolUIPart}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
case "tool-create_agent":
|
|
||||||
return (
|
|
||||||
<CreateAgentTool
|
|
||||||
key={`${message.id}-${i}`}
|
|
||||||
part={part as ToolUIPart}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
case "tool-edit_agent":
|
|
||||||
return (
|
|
||||||
<EditAgentTool
|
|
||||||
key={`${message.id}-${i}`}
|
|
||||||
part={part as ToolUIPart}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
case "tool-view_agent_output":
|
|
||||||
return (
|
|
||||||
<ViewAgentOutputTool
|
|
||||||
key={`${message.id}-${i}`}
|
|
||||||
part={part as ToolUIPart}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
default:
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
})}
|
|
||||||
{isLastAssistant &&
|
|
||||||
!messageHasVisibleContent &&
|
|
||||||
showThinking && (
|
|
||||||
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
|
||||||
{thinkingPhrase}
|
|
||||||
</span>
|
|
||||||
)}
|
|
||||||
</MessageContent>
|
|
||||||
</Message>
|
|
||||||
);
|
|
||||||
})}
|
|
||||||
{showThinking && lastMessage?.role !== "assistant" && (
|
|
||||||
<Message from="assistant">
|
|
||||||
<MessageContent className="text-[1rem] leading-relaxed">
|
|
||||||
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
|
||||||
{thinkingPhrase}
|
|
||||||
</span>
|
|
||||||
</MessageContent>
|
|
||||||
</Message>
|
|
||||||
)}
|
|
||||||
{error && (
|
|
||||||
<div className="rounded-lg bg-red-50 p-3 text-red-600">
|
|
||||||
Error: {error.message}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</ConversationContent>
|
|
||||||
<ConversationScrollButton />
|
|
||||||
</Conversation>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
@@ -1,188 +0,0 @@
"use client";

import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import { Button } from "@/components/atoms/Button/Button";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import {
  Sidebar,
  SidebarContent,
  SidebarFooter,
  SidebarHeader,
  SidebarTrigger,
  useSidebar,
} from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
import { motion } from "framer-motion";
import { parseAsString, useQueryState } from "nuqs";

export function ChatSidebar() {
  const { state } = useSidebar();
  const isCollapsed = state === "collapsed";
  const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);

  const { data: sessionsResponse, isLoading: isLoadingSessions } =
    useGetV2ListSessions({ limit: 50 });

  const sessions =
    sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];

  function handleNewChat() {
    setSessionId(null);
  }

  function handleSelectSession(id: string) {
    setSessionId(id);
  }

  function formatDate(dateString: string) {
    const date = new Date(dateString);
    const now = new Date();
    const diffMs = now.getTime() - date.getTime();
    const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));

    if (diffDays === 0) return "Today";
    if (diffDays === 1) return "Yesterday";
    if (diffDays < 7) return `${diffDays} days ago`;

    const day = date.getDate();
    const ordinal =
      day % 10 === 1 && day !== 11
        ? "st"
        : day % 10 === 2 && day !== 12
          ? "nd"
          : day % 10 === 3 && day !== 13
            ? "rd"
            : "th";
    const month = date.toLocaleDateString("en-US", { month: "short" });
    const year = date.getFullYear();

    return `${day}${ordinal} ${month} ${year}`;
  }

  return (
    <Sidebar
      variant="inset"
      collapsible="icon"
      className="!top-[50px] !h-[calc(100vh-50px)] border-r border-zinc-100 px-0"
    >
      {isCollapsed && (
        <SidebarHeader
          className={cn(
            "flex",
            isCollapsed
              ? "flex-row items-center justify-between gap-y-4 md:flex-col md:items-start md:justify-start"
              : "flex-row items-center justify-between",
          )}
        >
          <motion.div
            key={isCollapsed ? "header-collapsed" : "header-expanded"}
            className="flex flex-col items-center gap-3 pt-4"
            initial={{ opacity: 0, filter: "blur(3px)" }}
            animate={{ opacity: 1, filter: "blur(0px)" }}
            transition={{ type: "spring", bounce: 0.2 }}
          >
            <div className="flex flex-col items-center gap-2">
              <SidebarTrigger />
              <Button
                variant="ghost"
                onClick={handleNewChat}
                style={{ minWidth: "auto", width: "auto" }}
              >
                <PlusCircleIcon className="!size-5" />
                <span className="sr-only">New Chat</span>
              </Button>
            </div>
          </motion.div>
        </SidebarHeader>
      )}
      <SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
        {!isCollapsed && (
          <motion.div
            initial={{ opacity: 0 }}
            animate={{ opacity: 1 }}
            transition={{ duration: 0.2, delay: 0.1 }}
            className="flex items-center justify-between px-3"
          >
            <Text variant="h3" size="body-medium">
              Your chats
            </Text>
            <div className="relative left-6">
              <SidebarTrigger />
            </div>
          </motion.div>
        )}

        {!isCollapsed && (
          <motion.div
            initial={{ opacity: 0 }}
            animate={{ opacity: 1 }}
            transition={{ duration: 0.2, delay: 0.15 }}
            className="mt-4 flex flex-col gap-1"
          >
            {isLoadingSessions ? (
              <div className="flex min-h-[30rem] items-center justify-center py-4">
                <LoadingSpinner size="small" className="text-neutral-600" />
              </div>
            ) : sessions.length === 0 ? (
              <p className="py-4 text-center text-sm text-neutral-500">
                No conversations yet
              </p>
            ) : (
              sessions.map((session) => (
                <button
                  key={session.id}
                  onClick={() => handleSelectSession(session.id)}
                  className={cn(
                    "w-full rounded-lg px-3 py-2.5 text-left transition-colors",
                    session.id === sessionId
                      ? "bg-zinc-100"
                      : "hover:bg-zinc-50",
                  )}
                >
                  <div className="flex min-w-0 max-w-full flex-col overflow-hidden">
                    <div className="min-w-0 max-w-full">
                      <Text
                        variant="body"
                        className={cn(
                          "truncate font-normal",
                          session.id === sessionId
                            ? "text-zinc-600"
                            : "text-zinc-800",
                        )}
                      >
                        {session.title || `Untitled chat`}
                      </Text>
                    </div>
                    <Text variant="small" className="text-neutral-400">
                      {formatDate(session.updated_at)}
                    </Text>
                  </div>
                </button>
              ))
            )}
          </motion.div>
        )}
      </SidebarContent>
      {!isCollapsed && sessionId && (
        <SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
          <motion.div
            initial={{ opacity: 0 }}
            animate={{ opacity: 1 }}
            transition={{ duration: 0.2, delay: 0.2 }}
          >
            <Button
              variant="primary"
              size="small"
              onClick={handleNewChat}
              className="w-full"
              leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
            >
              New Chat
            </Button>
          </motion.div>
        </SidebarFooter>
      )}
    </Sidebar>
  );
}
@@ -1,16 +0,0 @@
"use client";

import { CopilotChatActionsContext } from "./useCopilotChatActions";

interface Props {
  onSend: (message: string) => void | Promise<void>;
  children: React.ReactNode;
}

export function CopilotChatActionsProvider({ onSend, children }: Props) {
  return (
    <CopilotChatActionsContext.Provider value={{ onSend }}>
      {children}
    </CopilotChatActionsContext.Provider>
  );
}
@@ -1,23 +0,0 @@
"use client";

import { createContext, useContext } from "react";

interface CopilotChatActions {
  onSend: (message: string) => void | Promise<void>;
}

const CopilotChatActionsContext = createContext<CopilotChatActions | null>(
  null,
);

export function useCopilotChatActions(): CopilotChatActions {
  const ctx = useContext(CopilotChatActionsContext);
  if (!ctx) {
    throw new Error(
      "useCopilotChatActions must be used within CopilotChatActionsProvider",
    );
  }
  return ctx;
}

export { CopilotChatActionsContext };
@@ -0,0 +1,99 @@
"use client";

import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
import { Text } from "@/components/atoms/Text/Text";
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
import type { ReactNode } from "react";
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { useCopilotShell } from "./useCopilotShell";

interface Props {
  children: ReactNode;
}

export function CopilotShell({ children }: Props) {
  const {
    isMobile,
    isDrawerOpen,
    isLoading,
    isCreatingSession,
    isLoggedIn,
    hasActiveSession,
    sessions,
    currentSessionId,
    handleOpenDrawer,
    handleCloseDrawer,
    handleDrawerOpenChange,
    handleNewChatClick,
    handleSessionClick,
    hasNextPage,
    isFetchingNextPage,
    fetchNextPage,
  } = useCopilotShell();

  if (!isLoggedIn) {
    return (
      <div className="flex h-full items-center justify-center">
        <ChatLoader />
      </div>
    );
  }

  return (
    <div
      className="flex overflow-hidden bg-[#EFEFF0]"
      style={{ height: `calc(100vh - ${NAVBAR_HEIGHT_PX}px)` }}
    >
      {!isMobile && (
        <DesktopSidebar
          sessions={sessions}
          currentSessionId={currentSessionId}
          isLoading={isLoading}
          hasNextPage={hasNextPage}
          isFetchingNextPage={isFetchingNextPage}
          onSelectSession={handleSessionClick}
          onFetchNextPage={fetchNextPage}
          onNewChat={handleNewChatClick}
          hasActiveSession={Boolean(hasActiveSession)}
        />
      )}

      <div className="relative flex min-h-0 flex-1 flex-col">
        {isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
        <div className="flex min-h-0 flex-1 flex-col">
          {isCreatingSession ? (
            <div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
              <div className="flex flex-col items-center gap-4">
                <ChatLoader />
                <Text variant="body" className="text-zinc-500">
                  Creating your chat...
                </Text>
              </div>
            </div>
          ) : (
            children
          )}
        </div>
      </div>

      {isMobile && (
        <MobileDrawer
          isOpen={isDrawerOpen}
          sessions={sessions}
          currentSessionId={currentSessionId}
          isLoading={isLoading}
          hasNextPage={hasNextPage}
          isFetchingNextPage={isFetchingNextPage}
          onSelectSession={handleSessionClick}
          onFetchNextPage={fetchNextPage}
          onNewChat={handleNewChatClick}
          onClose={handleCloseDrawer}
          onOpenChange={handleDrawerOpenChange}
          hasActiveSession={Boolean(hasActiveSession)}
        />
      )}
    </div>
  );
}
@@ -0,0 +1,70 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { Plus } from "@phosphor-icons/react";
import { SessionsList } from "../SessionsList/SessionsList";

interface Props {
  sessions: SessionSummaryResponse[];
  currentSessionId: string | null;
  isLoading: boolean;
  hasNextPage: boolean;
  isFetchingNextPage: boolean;
  onSelectSession: (sessionId: string) => void;
  onFetchNextPage: () => void;
  onNewChat: () => void;
  hasActiveSession: boolean;
}

export function DesktopSidebar({
  sessions,
  currentSessionId,
  isLoading,
  hasNextPage,
  isFetchingNextPage,
  onSelectSession,
  onFetchNextPage,
  onNewChat,
  hasActiveSession,
}: Props) {
  return (
    <aside className="flex h-full w-80 flex-col border-r border-zinc-100 bg-zinc-50">
      <div className="shrink-0 px-6 py-4">
        <Text variant="h3" size="body-medium">
          Your chats
        </Text>
      </div>
      <div
        className={cn(
          "flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
          scrollbarStyles,
        )}
      >
        <SessionsList
          sessions={sessions}
          currentSessionId={currentSessionId}
          isLoading={isLoading}
          hasNextPage={hasNextPage}
          isFetchingNextPage={isFetchingNextPage}
          onSelectSession={onSelectSession}
          onFetchNextPage={onFetchNextPage}
        />
      </div>
      {hasActiveSession && (
        <div className="shrink-0 bg-zinc-50 p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
          <Button
            variant="primary"
            size="small"
            onClick={onNewChat}
            className="w-full"
            leftIcon={<Plus width="1rem" height="1rem" />}
          >
            New Chat
          </Button>
        </div>
      )}
    </aside>
  );
}
@@ -0,0 +1,91 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { PlusIcon, X } from "@phosphor-icons/react";
import { Drawer } from "vaul";
import { SessionsList } from "../SessionsList/SessionsList";

interface Props {
  isOpen: boolean;
  sessions: SessionSummaryResponse[];
  currentSessionId: string | null;
  isLoading: boolean;
  hasNextPage: boolean;
  isFetchingNextPage: boolean;
  onSelectSession: (sessionId: string) => void;
  onFetchNextPage: () => void;
  onNewChat: () => void;
  onClose: () => void;
  onOpenChange: (open: boolean) => void;
  hasActiveSession: boolean;
}

export function MobileDrawer({
  isOpen,
  sessions,
  currentSessionId,
  isLoading,
  hasNextPage,
  isFetchingNextPage,
  onSelectSession,
  onFetchNextPage,
  onNewChat,
  onClose,
  onOpenChange,
  hasActiveSession,
}: Props) {
  return (
    <Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
      <Drawer.Portal>
        <Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
        <Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
          <div className="shrink-0 border-b border-zinc-200 p-4">
            <div className="flex items-center justify-between">
              <Drawer.Title className="text-lg font-semibold text-zinc-800">
                Your chats
              </Drawer.Title>
              <Button
                variant="icon"
                size="icon"
                aria-label="Close sessions"
                onClick={onClose}
              >
                <X width="1.25rem" height="1.25rem" />
              </Button>
            </div>
          </div>
          <div
            className={cn(
              "flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
              scrollbarStyles,
            )}
          >
            <SessionsList
              sessions={sessions}
              currentSessionId={currentSessionId}
              isLoading={isLoading}
              hasNextPage={hasNextPage}
              isFetchingNextPage={isFetchingNextPage}
              onSelectSession={onSelectSession}
              onFetchNextPage={onFetchNextPage}
            />
          </div>
          {hasActiveSession && (
            <div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
              <Button
                variant="primary"
                size="small"
                onClick={onNewChat}
                className="w-full"
                leftIcon={<PlusIcon width="1rem" height="1rem" />}
              >
                New Chat
              </Button>
            </div>
          )}
        </Drawer.Content>
      </Drawer.Portal>
    </Drawer.Root>
  );
}
@@ -0,0 +1,24 @@
import { useState } from "react";

export function useMobileDrawer() {
  const [isDrawerOpen, setIsDrawerOpen] = useState(false);

  const handleOpenDrawer = () => {
    setIsDrawerOpen(true);
  };

  const handleCloseDrawer = () => {
    setIsDrawerOpen(false);
  };

  const handleDrawerOpenChange = (open: boolean) => {
    setIsDrawerOpen(open);
  };

  return {
    isDrawerOpen,
    handleOpenDrawer,
    handleCloseDrawer,
    handleDrawerOpenChange,
  };
}
@@ -0,0 +1,80 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { InfiniteList } from "@/components/molecules/InfiniteList/InfiniteList";
import { cn } from "@/lib/utils";
import { getSessionTitle } from "../../helpers";

interface Props {
  sessions: SessionSummaryResponse[];
  currentSessionId: string | null;
  isLoading: boolean;
  hasNextPage: boolean;
  isFetchingNextPage: boolean;
  onSelectSession: (sessionId: string) => void;
  onFetchNextPage: () => void;
}

export function SessionsList({
  sessions,
  currentSessionId,
  isLoading,
  hasNextPage,
  isFetchingNextPage,
  onSelectSession,
  onFetchNextPage,
}: Props) {
  if (isLoading) {
    return (
      <div className="space-y-1">
        {Array.from({ length: 5 }).map((_, i) => (
          <div key={i} className="rounded-lg px-3 py-2.5">
            <Skeleton className="h-5 w-full" />
          </div>
        ))}
      </div>
    );
  }

  if (sessions.length === 0) {
    return (
      <div className="flex h-full items-center justify-center">
        <Text variant="body" className="text-zinc-500">
          You don't have previous chats
        </Text>
      </div>
    );
  }

  return (
    <InfiniteList
      items={sessions}
      hasMore={hasNextPage}
      isFetchingMore={isFetchingNextPage}
      onEndReached={onFetchNextPage}
      className="space-y-1"
      renderItem={(session) => {
        const isActive = session.id === currentSessionId;
        return (
          <button
            onClick={() => onSelectSession(session.id)}
            className={cn(
              "w-full rounded-lg px-3 py-2.5 text-left transition-colors",
              isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
            )}
          >
            <Text
              variant="body"
              className={cn(
                "font-normal",
                isActive ? "text-zinc-600" : "text-zinc-800",
              )}
            >
              {getSessionTitle(session)}
            </Text>
          </button>
        );
      }}
    />
  );
}
@@ -0,0 +1,91 @@
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { okData } from "@/app/api/helpers";
import { useEffect, useState } from "react";

const PAGE_SIZE = 50;

export interface UseSessionsPaginationArgs {
  enabled: boolean;
}

export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
  const [offset, setOffset] = useState(0);

  const [accumulatedSessions, setAccumulatedSessions] = useState<
    SessionSummaryResponse[]
  >([]);

  const [totalCount, setTotalCount] = useState<number | null>(null);

  const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
    { limit: PAGE_SIZE, offset },
    {
      query: {
        enabled: enabled && offset >= 0,
      },
    },
  );

  useEffect(() => {
    const responseData = okData(data);
    if (responseData) {
      const newSessions = responseData.sessions;
      const total = responseData.total;
      setTotalCount(total);

      if (offset === 0) {
        setAccumulatedSessions(newSessions);
      } else {
        setAccumulatedSessions((prev) => [...prev, ...newSessions]);
      }
    } else if (!enabled) {
      setAccumulatedSessions([]);
      setTotalCount(null);
    }
  }, [data, offset, enabled]);

  const hasNextPage =
    totalCount !== null && accumulatedSessions.length < totalCount;

  const areAllSessionsLoaded =
    totalCount !== null &&
    accumulatedSessions.length >= totalCount &&
    !isFetching &&
    !isLoading;

  useEffect(() => {
    if (
      hasNextPage &&
      !isFetching &&
      !isLoading &&
      !isError &&
      totalCount !== null
    ) {
      setOffset((prev) => prev + PAGE_SIZE);
    }
  }, [hasNextPage, isFetching, isLoading, isError, totalCount]);

  const fetchNextPage = () => {
    if (hasNextPage && !isFetching) {
      setOffset((prev) => prev + PAGE_SIZE);
    }
  };

  const reset = () => {
    // Only reset the offset - keep existing sessions visible during refetch
    // The effect will replace sessions when new data arrives at offset 0
    setOffset(0);
  };

  return {
    sessions: accumulatedSessions,
    isLoading,
    isFetching,
    hasNextPage,
    areAllSessionsLoaded,
    totalCount,
    fetchNextPage,
    reset,
  };
}
@@ -0,0 +1,106 @@
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { format, formatDistanceToNow, isToday } from "date-fns";

export function convertSessionDetailToSummary(session: SessionDetailResponse) {
  return {
    id: session.id,
    created_at: session.created_at,
    updated_at: session.updated_at,
    title: undefined,
  };
}

export function filterVisibleSessions(sessions: SessionSummaryResponse[]) {
  const fiveMinutesAgo = Date.now() - 5 * 60 * 1000;
  return sessions.filter((session) => {
    const hasBeenUpdated = session.updated_at !== session.created_at;

    if (hasBeenUpdated) return true;

    const isRecentlyCreated =
      new Date(session.created_at).getTime() > fiveMinutesAgo;

    return isRecentlyCreated;
  });
}

export function getSessionTitle(session: SessionSummaryResponse) {
  if (session.title) return session.title;

  const isNewSession = session.updated_at === session.created_at;

  if (isNewSession) {
    const createdDate = new Date(session.created_at);
    if (isToday(createdDate)) {
      return "Today";
    }
    return format(createdDate, "MMM d, yyyy");
  }

  return "Untitled Chat";
}

export function getSessionUpdatedLabel(session: SessionSummaryResponse) {
  if (!session.updated_at) return "";
  return formatDistanceToNow(new Date(session.updated_at), { addSuffix: true });
}

export function mergeCurrentSessionIntoList(
  accumulatedSessions: SessionSummaryResponse[],
  currentSessionId: string | null,
  currentSessionData: SessionDetailResponse | null | undefined,
  recentlyCreatedSessions?: Map<string, SessionSummaryResponse>,
) {
  const filteredSessions: SessionSummaryResponse[] = [];
  const addedIds = new Set<string>();

  if (accumulatedSessions.length > 0) {
    const visibleSessions = filterVisibleSessions(accumulatedSessions);

    if (currentSessionId) {
      const currentInAll = accumulatedSessions.find(
        (s) => s.id === currentSessionId,
      );
      if (currentInAll) {
        const isInVisible = visibleSessions.some(
          (s) => s.id === currentSessionId,
        );
        if (!isInVisible) {
          filteredSessions.push(currentInAll);
          addedIds.add(currentInAll.id);
        }
      }
    }

    for (const session of visibleSessions) {
      if (!addedIds.has(session.id)) {
        filteredSessions.push(session);
        addedIds.add(session.id);
      }
    }
  }

  if (currentSessionId && currentSessionData) {
    if (!addedIds.has(currentSessionId)) {
      const summarySession = convertSessionDetailToSummary(currentSessionData);
      filteredSessions.unshift(summarySession);
      addedIds.add(currentSessionId);
    }
  }

  if (recentlyCreatedSessions) {
    for (const [sessionId, sessionData] of recentlyCreatedSessions) {
      if (!addedIds.has(sessionId)) {
        filteredSessions.unshift(sessionData);
        addedIds.add(sessionId);
      }
    }
  }

  return filteredSessions;
}

export function getCurrentSessionId(searchParams: URLSearchParams) {
  return searchParams.get("sessionId");
}
@@ -0,0 +1,124 @@
"use client";

import {
  getGetV2GetSessionQueryKey,
  getGetV2ListSessionsQueryKey,
  useGetV2GetSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { okData } from "@/app/api/helpers";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import { usePathname, useSearchParams } from "next/navigation";
import { useCopilotStore } from "../../copilot-page-store";
import { useCopilotSessionId } from "../../useCopilotSessionId";
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
import { getCurrentSessionId } from "./helpers";
import { useShellSessionList } from "./useShellSessionList";

export function useCopilotShell() {
  const pathname = usePathname();
  const searchParams = useSearchParams();
  const queryClient = useQueryClient();
  const breakpoint = useBreakpoint();
  const { isLoggedIn } = useSupabase();
  const isMobile =
    breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";

  const { urlSessionId, setUrlSessionId } = useCopilotSessionId();

  const isOnHomepage = pathname === "/copilot";
  const paramSessionId = searchParams.get("sessionId");

  const {
    isDrawerOpen,
    handleOpenDrawer,
    handleCloseDrawer,
    handleDrawerOpenChange,
  } = useMobileDrawer();

  const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;

  const currentSessionId = getCurrentSessionId(searchParams);

  const { data: currentSessionData } = useGetV2GetSession(
    currentSessionId || "",
    {
      query: {
        enabled: !!currentSessionId,
        select: okData,
      },
    },
  );

  const {
    sessions,
    isLoading,
    isSessionsFetching,
    hasNextPage,
    fetchNextPage,
    resetPagination,
    recentlyCreatedSessionsRef,
  } = useShellSessionList({
    paginationEnabled,
    currentSessionId,
    currentSessionData,
    isOnHomepage,
    paramSessionId,
  });

  const stopStream = useChatStore((s) => s.stopStream);
  const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);

  function handleSessionClick(sessionId: string) {
    if (sessionId === currentSessionId) return;

    // Stop current stream - SSE reconnection allows resuming later
    if (currentSessionId) {
      stopStream(currentSessionId);
    }

    if (recentlyCreatedSessionsRef.current.has(sessionId)) {
      queryClient.invalidateQueries({
        queryKey: getGetV2GetSessionQueryKey(sessionId),
      });
    }
    setUrlSessionId(sessionId, { shallow: false });
    if (isMobile) handleCloseDrawer();
  }

  function handleNewChatClick() {
    // Stop current stream - SSE reconnection allows resuming later
    if (currentSessionId) {
      stopStream(currentSessionId);
    }

    resetPagination();
    queryClient.invalidateQueries({
      queryKey: getGetV2ListSessionsQueryKey(),
    });
    setUrlSessionId(null, { shallow: false });
    if (isMobile) handleCloseDrawer();
  }

  return {
    isMobile,
    isDrawerOpen,
    isLoggedIn,
    hasActiveSession:
      Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
    isLoading: isLoading || isCreatingSession,
    isCreatingSession,
    sessions,
    currentSessionId: urlSessionId,
    handleOpenDrawer,
    handleCloseDrawer,
    handleDrawerOpenChange,
    handleNewChatClick,
    handleSessionClick,
    hasNextPage,
    isFetchingNextPage: isSessionsFetching,
    fetchNextPage,
  };
}
@@ -0,0 +1,113 @@
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo, useRef } from "react";
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
import {
  convertSessionDetailToSummary,
  filterVisibleSessions,
  mergeCurrentSessionIntoList,
} from "./helpers";

interface UseShellSessionListArgs {
  paginationEnabled: boolean;
  currentSessionId: string | null;
  currentSessionData: SessionDetailResponse | null | undefined;
  isOnHomepage: boolean;
  paramSessionId: string | null;
}

export function useShellSessionList({
  paginationEnabled,
  currentSessionId,
  currentSessionData,
  isOnHomepage,
  paramSessionId,
}: UseShellSessionListArgs) {
  const queryClient = useQueryClient();
  const onStreamComplete = useChatStore((s) => s.onStreamComplete);

  const {
    sessions: accumulatedSessions,
    isLoading: isSessionsLoading,
    isFetching: isSessionsFetching,
    hasNextPage,
    fetchNextPage,
    reset: resetPagination,
  } = useSessionsPagination({
    enabled: paginationEnabled,
  });

  const recentlyCreatedSessionsRef = useRef<
    Map<string, SessionSummaryResponse>
  >(new Map());

  useEffect(() => {
    if (isOnHomepage && !paramSessionId) {
      queryClient.invalidateQueries({
        queryKey: getGetV2ListSessionsQueryKey(),
      });
    }
  }, [isOnHomepage, paramSessionId, queryClient]);

  useEffect(() => {
    if (currentSessionId && currentSessionData) {
      const isNewSession =
        currentSessionData.updated_at === currentSessionData.created_at;
      const isNotInAccumulated = !accumulatedSessions.some(
        (s) => s.id === currentSessionId,
      );
      if (isNewSession || isNotInAccumulated) {
        const summary = convertSessionDetailToSummary(currentSessionData);
        recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
      }
    }
  }, [currentSessionId, currentSessionData, accumulatedSessions]);

  useEffect(() => {
    for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
      if (accumulatedSessions.some((s) => s.id === sessionId)) {
        recentlyCreatedSessionsRef.current.delete(sessionId);
      }
    }
  }, [accumulatedSessions]);

  useEffect(() => {
    const unsubscribe = onStreamComplete(() => {
      queryClient.invalidateQueries({
        queryKey: getGetV2ListSessionsQueryKey(),
      });
    });
    return unsubscribe;
  }, [onStreamComplete, queryClient]);

  const sessions = useMemo(
    () =>
      mergeCurrentSessionIntoList(
        accumulatedSessions,
        currentSessionId,
        currentSessionData,
        recentlyCreatedSessionsRef.current,
      ),
    [accumulatedSessions, currentSessionId, currentSessionData],
  );

  const visibleSessions = useMemo(
    () => filterVisibleSessions(sessions),
    [sessions],
  );

  const isLoading = isSessionsLoading && accumulatedSessions.length === 0;

  return {
    sessions: visibleSessions,
    isLoading,
    isSessionsFetching,
    hasNextPage,
    fetchNextPage,
    resetPagination,
    recentlyCreatedSessionsRef,
  };
}
@@ -1,111 +0,0 @@
"use client";

import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { SpinnerGapIcon } from "@phosphor-icons/react";
import { motion } from "framer-motion";
import { useEffect, useState } from "react";
import {
  getGreetingName,
  getInputPlaceholder,
  getQuickActions,
} from "./helpers";

interface Props {
  inputLayoutId: string;
  isCreatingSession: boolean;
  onCreateSession: () => void | Promise<string>;
  onSend: (message: string) => void | Promise<void>;
}

export function EmptySession({
  inputLayoutId,
  isCreatingSession,
  onSend,
}: Props) {
  const { user } = useSupabase();
  const greetingName = getGreetingName(user);
  const quickActions = getQuickActions();
  const [loadingAction, setLoadingAction] = useState<string | null>(null);
  const [inputPlaceholder, setInputPlaceholder] = useState(
    getInputPlaceholder(),
  );

  useEffect(() => {
    setInputPlaceholder(getInputPlaceholder(window.innerWidth));
  }, [window.innerWidth]);

  async function handleQuickActionClick(action: string) {
    if (isCreatingSession || loadingAction) return;

    setLoadingAction(action);
    try {
      await onSend(action);
    } finally {
      setLoadingAction(null);
    }
  }

  return (
    <div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-0 py-5 md:px-6 md:py-10">
      <motion.div
        className="w-full max-w-3xl text-center"
        initial={{ opacity: 0 }}
        animate={{ opacity: 1 }}
        transition={{ duration: 0.3 }}
      >
        <div className="mx-auto max-w-3xl">
          <Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
            Hey, <span className="text-violet-600">{greetingName}</span>
          </Text>
          <Text variant="h3" className="mb-8 !font-normal">
            Tell me about your work — I'll find what to automate.
          </Text>

          <div className="mb-6">
            <motion.div
              layoutId={inputLayoutId}
              transition={{ type: "spring", bounce: 0.2, duration: 0.65 }}
              className="w-full px-2"
            >
              <ChatInput
                inputId="chat-input-empty"
                onSend={onSend}
                disabled={isCreatingSession}
                placeholder={inputPlaceholder}
                className="w-full"
              />
            </motion.div>
          </div>
        </div>

        <div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
          {quickActions.map((action) => (
            <Button
              key={action}
              type="button"
              variant="outline"
              size="small"
              onClick={() => void handleQuickActionClick(action)}
              disabled={isCreatingSession || loadingAction !== null}
              aria-busy={loadingAction === action}
              leftIcon={
                loadingAction === action ? (
                  <SpinnerGapIcon
                    className="h-4 w-4 animate-spin"
                    weight="bold"
                  />
                ) : null
              }
              className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
            >
              {action}
            </Button>
          ))}
        </div>
      </motion.div>
    </div>
  );
}
@@ -1,140 +0,0 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { PlusIcon, SpinnerGapIcon, X } from "@phosphor-icons/react";
import { Drawer } from "vaul";

interface Props {
  isOpen: boolean;
  sessions: SessionSummaryResponse[];
  currentSessionId: string | null;
  isLoading: boolean;
  onSelectSession: (sessionId: string) => void;
  onNewChat: () => void;
  onClose: () => void;
  onOpenChange: (open: boolean) => void;
}

function formatDate(dateString: string) {
  const date = new Date(dateString);
  const now = new Date();
  const diffMs = now.getTime() - date.getTime();
  const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));

  if (diffDays === 0) return "Today";
  if (diffDays === 1) return "Yesterday";
  if (diffDays < 7) return `${diffDays} days ago`;

  const day = date.getDate();
  const ordinal =
    day % 10 === 1 && day !== 11
      ? "st"
      : day % 10 === 2 && day !== 12
        ? "nd"
        : day % 10 === 3 && day !== 13
          ? "rd"
          : "th";
  const month = date.toLocaleDateString("en-US", { month: "short" });
  const year = date.getFullYear();

  return `${day}${ordinal} ${month} ${year}`;
}

export function MobileDrawer({
  isOpen,
  sessions,
  currentSessionId,
  isLoading,
  onSelectSession,
  onNewChat,
  onClose,
  onOpenChange,
}: Props) {
  return (
    <Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
      <Drawer.Portal>
        <Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
        <Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
          <div className="shrink-0 border-b border-zinc-200 px-4 py-2">
            <div className="flex items-center justify-between">
              <Drawer.Title className="text-lg font-semibold text-zinc-800">
                Your chats
              </Drawer.Title>
              <Button
                variant="icon"
                size="icon"
                aria-label="Close sessions"
                onClick={onClose}
              >
                <X width="1rem" height="1rem" />
              </Button>
            </div>
          </div>
          <div
            className={cn(
              "flex min-h-0 flex-1 flex-col gap-1 overflow-y-auto px-3 py-3",
              scrollbarStyles,
            )}
          >
            {isLoading ? (
              <div className="flex items-center justify-center py-4">
                <SpinnerGapIcon className="h-5 w-5 animate-spin text-neutral-400" />
              </div>
            ) : sessions.length === 0 ? (
              <p className="py-4 text-center text-sm text-neutral-500">
                No conversations yet
              </p>
            ) : (
              sessions.map((session) => (
                <button
                  key={session.id}
                  onClick={() => onSelectSession(session.id)}
                  className={cn(
                    "w-full rounded-lg px-3 py-2.5 text-left transition-colors",
                    session.id === currentSessionId
                      ? "bg-zinc-100"
                      : "hover:bg-zinc-50",
                  )}
                >
                  <div className="flex min-w-0 max-w-full flex-col overflow-hidden">
                    <div className="min-w-0 max-w-full">
                      <Text
                        variant="body"
                        className={cn(
                          "truncate font-normal",
                          session.id === currentSessionId
                            ? "text-zinc-600"
                            : "text-zinc-800",
                        )}
                      >
                        {session.title || "Untitled chat"}
                      </Text>
                    </div>
                    <Text variant="small" className="text-neutral-400">
                      {formatDate(session.updated_at)}
                    </Text>
                  </div>
                </button>
              ))
            )}
          </div>
          {currentSessionId && (
            <div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
              <Button
                variant="primary"
                size="small"
                onClick={onNewChat}
                className="w-full"
                leftIcon={<PlusIcon width="1rem" height="1rem" />}
              >
                New Chat
              </Button>
            </div>
          )}
        </Drawer.Content>
      </Drawer.Portal>
    </Drawer.Root>
  );
}
@@ -1,54 +0,0 @@
import { cn } from "@/lib/utils";
import { AnimatePresence, motion } from "framer-motion";

interface Props {
  text: string;
  className?: string;
}

export function MorphingTextAnimation({ text, className }: Props) {
  const letters = text.split("");

  return (
    <div className={cn(className)}>
      <AnimatePresence mode="popLayout" initial={false}>
        <motion.div key={text} className="whitespace-nowrap">
          <motion.span className="inline-flex overflow-hidden">
            {letters.map((char, index) => (
              <motion.span
                key={`${text}-${index}`}
                initial={{
                  opacity: 0,
                  y: 8,
                  rotateX: "80deg",
                  filter: "blur(6px)",
                }}
                animate={{
                  opacity: 1,
                  y: 0,
                  rotateX: "0deg",
                  filter: "blur(0px)",
                }}
                exit={{
                  opacity: 0,
                  y: -8,
                  rotateX: "-80deg",
                  filter: "blur(6px)",
                }}
                style={{ willChange: "transform" }}
                transition={{
                  delay: 0.015 * index,
                  type: "spring",
                  bounce: 0.5,
                }}
                className="inline-block"
              >
                {char === " " ? "\u00A0" : char}
              </motion.span>
            ))}
          </motion.span>
        </motion.div>
      </AnimatePresence>
    </div>
  );
}
@@ -1,69 +0,0 @@
.loader {
  position: relative;
  animation: rotate 1s infinite;
}

.loader::before,
.loader::after {
  border-radius: 50%;
  content: "";
  display: block;
  /* 40% of container size */
  height: 40%;
  width: 40%;
}

.loader::before {
  animation: ball1 1s infinite;
  background-color: #a1a1aa; /* zinc-400 */
  box-shadow: calc(var(--spacing)) 0 0 #18181b; /* zinc-900 */
  margin-bottom: calc(var(--gap));
}

.loader::after {
  animation: ball2 1s infinite;
  background-color: #18181b; /* zinc-900 */
  box-shadow: calc(var(--spacing)) 0 0 #a1a1aa; /* zinc-400 */
}

@keyframes rotate {
  0% {
    transform: rotate(0deg) scale(0.8);
  }
  50% {
    transform: rotate(360deg) scale(1.2);
  }
  100% {
    transform: rotate(720deg) scale(0.8);
  }
}

@keyframes ball1 {
  0% {
    box-shadow: calc(var(--spacing)) 0 0 #18181b;
  }
  50% {
    box-shadow: 0 0 0 #18181b;
    margin-bottom: 0;
    transform: translate(calc(var(--spacing) / 2), calc(var(--spacing) / 2));
  }
  100% {
    box-shadow: calc(var(--spacing)) 0 0 #18181b;
    margin-bottom: calc(var(--gap));
  }
}

@keyframes ball2 {
  0% {
    box-shadow: calc(var(--spacing)) 0 0 #a1a1aa;
  }
  50% {
    box-shadow: 0 0 0 #a1a1aa;
    margin-top: calc(var(--ball-size) * -1);
    transform: translate(calc(var(--spacing) / 2), calc(var(--spacing) / 2));
  }
  100% {
    box-shadow: calc(var(--spacing)) 0 0 #a1a1aa;
    margin-top: 0;
  }
}
@@ -1,28 +0,0 @@
import { cn } from "@/lib/utils";
import styles from "./OrbitLoader.module.css";

interface Props {
  size?: number;
  className?: string;
}

export function OrbitLoader({ size = 24, className }: Props) {
  const ballSize = Math.round(size * 0.4);
  const spacing = Math.round(size * 0.6);
  const gap = Math.round(size * 0.2);

  return (
    <div
      className={cn(styles.loader, className)}
      style={
        {
          width: size,
          height: size,
          "--ball-size": `${ballSize}px`,
          "--spacing": `${spacing}px`,
          "--gap": `${gap}px`,
        } as React.CSSProperties
      }
    />
  );
}
@@ -1,26 +0,0 @@
import { cn } from "@/lib/utils";

interface Props {
  value: number;
  label?: string;
  className?: string;
}

export function ProgressBar({ value, label, className }: Props) {
  const clamped = Math.min(100, Math.max(0, value));

  return (
    <div className={cn("flex flex-col gap-1.5", className)}>
      <div className="flex items-center justify-between text-xs text-neutral-500">
        <span>{label ?? "Working on it..."}</span>
        <span>{Math.round(clamped)}%</span>
      </div>
      <div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
        <div
          className="h-full rounded-full bg-neutral-900 transition-[width] duration-300 ease-out"
          style={{ width: `${clamped}%` }}
        />
      </div>
    </div>
  );
}
@@ -1,34 +0,0 @@
.loader {
  position: relative;
  display: inline-block;
  flex-shrink: 0;
}

.loader::before,
.loader::after {
  content: "";
  box-sizing: border-box;
  width: 100%;
  height: 100%;
  border-radius: 50%;
  background: currentColor;
  position: absolute;
  left: 0;
  top: 0;
  animation: ripple 2s linear infinite;
}

.loader::after {
  animation-delay: 1s;
}

@keyframes ripple {
  0% {
    transform: scale(0);
    opacity: 1;
  }
  100% {
    transform: scale(1);
    opacity: 0;
  }
}
@@ -1,16 +0,0 @@
import { cn } from "@/lib/utils";
import styles from "./PulseLoader.module.css";

interface Props {
  size?: number;
  className?: string;
}

export function PulseLoader({ size = 24, className }: Props) {
  return (
    <div
      className={cn(styles.loader, className)}
      style={{ width: size, height: size }}
    />
  );
}
@@ -1,35 +0,0 @@
.loader {
  width: 48px;
  height: 48px;
  display: inline-block;
  position: relative;
}

.loader::after,
.loader::before {
  content: "";
  box-sizing: border-box;
  width: 100%;
  height: 100%;
  border-radius: 50%;
  background: currentColor;
  position: absolute;
  left: 0;
  top: 0;
  animation: animloader 2s linear infinite;
}

.loader::after {
  animation-delay: 1s;
}

@keyframes animloader {
  0% {
    transform: scale(0);
    opacity: 1;
  }
  100% {
    transform: scale(1);
    opacity: 0;
  }
}
@@ -1,16 +0,0 @@
import { cn } from "@/lib/utils";
import styles from "./ScaleLoader.module.css";

interface Props {
  size?: number;
  className?: string;
}

export function ScaleLoader({ size = 48, className }: Props) {
  return (
    <div
      className={cn(styles.loader, className)}
      style={{ width: size, height: size }}
    />
  );
}
@@ -1,57 +0,0 @@
.loader {
  position: relative;
  display: inline-block;
  flex-shrink: 0;
  transform: rotateZ(45deg);
  perspective: 1000px;
  border-radius: 50%;
  color: currentColor;
}

.loader::before,
.loader::after {
  content: "";
  display: block;
  position: absolute;
  top: 0;
  left: 0;
  width: inherit;
  height: inherit;
  border-radius: 50%;
  transform: rotateX(70deg);
  animation: spin 1s linear infinite;
}

.loader::after {
  color: var(--spinner-accent, #a855f7);
  transform: rotateY(70deg);
  animation-delay: 0.4s;
}

@keyframes spin {
  0%,
  100% {
    box-shadow: 0.2em 0 0 0 currentColor;
  }
  12% {
    box-shadow: 0.2em 0.2em 0 0 currentColor;
  }
  25% {
    box-shadow: 0 0.2em 0 0 currentColor;
  }
  37% {
    box-shadow: -0.2em 0.2em 0 0 currentColor;
  }
  50% {
    box-shadow: -0.2em 0 0 0 currentColor;
  }
  62% {
    box-shadow: -0.2em -0.2em 0 0 currentColor;
  }
  75% {
    box-shadow: 0 -0.2em 0 0 currentColor;
  }
  87% {
    box-shadow: 0.2em -0.2em 0 0 currentColor;
  }
}
@@ -1,16 +0,0 @@
import { cn } from "@/lib/utils";
import styles from "./SpinnerLoader.module.css";

interface Props {
  size?: number;
  className?: string;
}

export function SpinnerLoader({ size = 24, className }: Props) {
  return (
    <div
      className={cn(styles.loader, className)}
      style={{ width: size, height: size }}
    />
  );
}
@@ -1,235 +0,0 @@
import { Link } from "@/components/atoms/Link/Link";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";

/* ------------------------------------------------------------------ */
/* Layout */
/* ------------------------------------------------------------------ */

export function ContentGrid({
  children,
  className,
}: {
  children: React.ReactNode;
  className?: string;
}) {
  return <div className={cn("grid gap-2", className)}>{children}</div>;
}

/* ------------------------------------------------------------------ */
/* Card */
/* ------------------------------------------------------------------ */

export function ContentCard({
  children,
  className,
}: {
  children: React.ReactNode;
  className?: string;
}) {
  return (
    <div
      className={cn(
        "rounded-lg bg-gradient-to-r from-purple-500/30 to-blue-500/30 p-[1px]",
        className,
      )}
    >
      <div className="rounded-lg bg-neutral-100 p-3">{children}</div>
    </div>
  );
}

/** Flex row with a left content area (`children`) and an optional right‑side `action`. */
export function ContentCardHeader({
  children,
  action,
  className,
}: {
  children: React.ReactNode;
  action?: React.ReactNode;
  className?: string;
}) {
  return (
    <div className={cn("flex items-start justify-between gap-2", className)}>
      <div className="min-w-0">{children}</div>
      {action}
    </div>
  );
}

export function ContentCardTitle({
  children,
  className,
}: {
  children: React.ReactNode;
  className?: string;
}) {
  return (
    <Text
      variant="body-medium"
      className={cn("truncate text-zinc-800", className)}
    >
      {children}
    </Text>
  );
}

export function ContentCardSubtitle({
  children,
  className,
}: {
  children: React.ReactNode;
  className?: string;
}) {
  return (
    <Text
      variant="small"
      className={cn("mt-0.5 truncate font-mono text-zinc-800", className)}
    >
      {children}
    </Text>
  );
}

export function ContentCardDescription({
  children,
  className,
}: {
  children: React.ReactNode;
  className?: string;
}) {
  return (
    <Text variant="body" className={cn("mt-2 text-zinc-800", className)}>
      {children}
    </Text>
  );
}

/* ------------------------------------------------------------------ */
/* Text */
/* ------------------------------------------------------------------ */

export function ContentMessage({
  children,
  className,
}: {
  children: React.ReactNode;
  className?: string;
}) {
  return (
    <Text variant="body" className={cn("text-zinc-800", className)}>
      {children}
    </Text>
  );
}

export function ContentHint({
  children,
  className,
}: {
  children: React.ReactNode;
  className?: string;
}) {
  return (
    <Text variant="small" className={cn("text-neutral-500", className)}>
      {children}
    </Text>
  );
}

/* ------------------------------------------------------------------ */
/* Code / data */
/* ------------------------------------------------------------------ */

export function ContentCodeBlock({
  children,
  className,
}: {
  children: React.ReactNode;
  className?: string;
}) {
  return (
    <pre
      className={cn(
        "whitespace-pre-wrap rounded-lg border bg-black p-3 text-xs text-neutral-200",
        className,
      )}
    >
      {children}
    </pre>
  );
}

/* ------------------------------------------------------------------ */
/* Inline elements */
/* ------------------------------------------------------------------ */

export function ContentBadge({
  children,
  className,
}: {
  children: React.ReactNode;
  className?: string;
}) {
  return (
    <Text
      variant="small"
      as="span"
      className={cn(
        "shrink-0 rounded-full border bg-muted px-2 py-0.5 text-[11px] text-zinc-800",
        className,
      )}
    >
      {children}
    </Text>
  );
}

export function ContentLink({
  href,
  children,
  className,
  ...rest
}: Omit<React.ComponentProps<typeof Link>, "className"> & {
  className?: string;
}) {
  return (
    <Link
      variant="primary"
      isExternal
      href={href}
      className={cn("shrink-0 text-xs text-purple-500", className)}
      {...rest}
    >
      {children}
    </Link>
  );
}

/* ------------------------------------------------------------------ */
/* Lists */
/* ------------------------------------------------------------------ */

export function ContentSuggestionsList({
  items,
  max = 5,
  className,
}: {
  items: string[];
  max?: number;
  className?: string;
}) {
  if (items.length === 0) return null;
  return (
    <ul
      className={cn(
        "mt-2 list-disc space-y-1 pl-5 font-sans text-[0.75rem] leading-[1.125rem] text-zinc-800",
        className,
      )}
    >
      {items.slice(0, max).map((s) => (
        <li key={s}>{s}</li>
      ))}
    </ul>
  );
}
@@ -1,102 +0,0 @@
"use client";

import { cn } from "@/lib/utils";
import { CaretDownIcon } from "@phosphor-icons/react";
import { AnimatePresence, motion, useReducedMotion } from "framer-motion";
import { useId } from "react";
import { useToolAccordion } from "./useToolAccordion";

interface Props {
  icon: React.ReactNode;
  title: React.ReactNode;
  titleClassName?: string;
  description?: React.ReactNode;
  children: React.ReactNode;
  className?: string;
  defaultExpanded?: boolean;
  expanded?: boolean;
  onExpandedChange?: (expanded: boolean) => void;
}

export function ToolAccordion({
  icon,
  title,
  titleClassName,
  description,
  children,
  className,
  defaultExpanded,
  expanded,
  onExpandedChange,
}: Props) {
  const shouldReduceMotion = useReducedMotion();
  const contentId = useId();
  const { isExpanded, toggle } = useToolAccordion({
    expanded,
    defaultExpanded,
    onExpandedChange,
  });

  return (
    <div
      className={cn(
        "mt-2 w-full rounded-lg border border-slate-200 bg-slate-100 px-3 py-2",
        className,
      )}
    >
      <button
        type="button"
        aria-expanded={isExpanded}
        aria-controls={contentId}
        onClick={toggle}
        className="flex w-full items-center justify-between gap-3 py-1 text-left"
      >
        <div className="flex min-w-0 items-center gap-3">
          <span className="flex shrink-0 items-center text-gray-800">
            {icon}
          </span>
          <div className="min-w-0">
            <p
              className={cn(
                "truncate text-sm font-medium text-gray-800",
                titleClassName,
              )}
            >
              {title}
            </p>
            {description && (
              <p className="truncate text-xs text-slate-800">{description}</p>
            )}
          </div>
        </div>
        <CaretDownIcon
          className={cn(
            "h-4 w-4 shrink-0 text-slate-500 transition-transform",
            isExpanded && "rotate-180",
          )}
          weight="bold"
        />
      </button>

      <AnimatePresence initial={false}>
        {isExpanded && (
          <motion.div
            id={contentId}
            initial={{ height: 0, opacity: 0, filter: "blur(10px)" }}
            animate={{ height: "auto", opacity: 1, filter: "blur(0px)" }}
            exit={{ height: 0, opacity: 0, filter: "blur(10px)" }}
            transition={
              shouldReduceMotion
                ? { duration: 0 }
                : { type: "spring", bounce: 0.35, duration: 0.55 }
            }
            className="overflow-hidden"
            style={{ willChange: "height, opacity, filter" }}
          >
            <div className="pb-2 pt-3">{children}</div>
          </motion.div>
        )}
      </AnimatePresence>
    </div>
  );
}
@@ -1,32 +0,0 @@
import { useState } from "react";

interface UseToolAccordionOptions {
  expanded?: boolean;
  defaultExpanded?: boolean;
  onExpandedChange?: (expanded: boolean) => void;
}

interface UseToolAccordionResult {
  isExpanded: boolean;
  toggle: () => void;
}

export function useToolAccordion({
  expanded,
  defaultExpanded = false,
  onExpandedChange,
}: UseToolAccordionOptions): UseToolAccordionResult {
  const [uncontrolledExpanded, setUncontrolledExpanded] =
    useState(defaultExpanded);

  const isControlled = typeof expanded === "boolean";
  const isExpanded = isControlled ? expanded : uncontrolledExpanded;

  function toggle() {
    const next = !isExpanded;
    if (!isControlled) setUncontrolledExpanded(next);
    onExpandedChange?.(next);
  }

  return { isExpanded, toggle };
}
@@ -0,0 +1,56 @@
"use client";

import { create } from "zustand";

interface CopilotStoreState {
  isStreaming: boolean;
  isSwitchingSession: boolean;
  isCreatingSession: boolean;
  isInterruptModalOpen: boolean;
  pendingAction: (() => void) | null;
}

interface CopilotStoreActions {
  setIsStreaming: (isStreaming: boolean) => void;
  setIsSwitchingSession: (isSwitchingSession: boolean) => void;
  setIsCreatingSession: (isCreating: boolean) => void;
  openInterruptModal: (onConfirm: () => void) => void;
  confirmInterrupt: () => void;
  cancelInterrupt: () => void;
}

type CopilotStore = CopilotStoreState & CopilotStoreActions;

export const useCopilotStore = create<CopilotStore>((set, get) => ({
  isStreaming: false,
  isSwitchingSession: false,
  isCreatingSession: false,
  isInterruptModalOpen: false,
  pendingAction: null,

  setIsStreaming(isStreaming) {
    set({ isStreaming });
  },

  setIsSwitchingSession(isSwitchingSession) {
    set({ isSwitchingSession });
  },

  setIsCreatingSession(isCreatingSession) {
    set({ isCreatingSession });
  },

  openInterruptModal(onConfirm) {
    set({ isInterruptModalOpen: true, pendingAction: onConfirm });
  },

  confirmInterrupt() {
    const { pendingAction } = get();
    set({ isInterruptModalOpen: false, pendingAction: null });
    if (pendingAction) pendingAction();
  },

  cancelInterrupt() {
    set({ isInterruptModalOpen: false, pendingAction: null });
  },
}));
@@ -1,26 +1,6 @@
import { User } from "@supabase/supabase-js";
import type { User } from "@supabase/supabase-js";

export function getInputPlaceholder(width?: number) {
export function getGreetingName(user?: User | null): string {
  if (!width) return "What's your role and what eats up most of your day?";

  if (width < 500) {
    return "I'm a chef and I hate...";
  }
  if (width <= 1080) {
    return "What's your role and what eats up most of your day?";
  }
  return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}

export function getQuickActions() {
  return [
    "I don't know where to start, just ask me stuff",
    "I do the same thing every week and it's killing me",
    "Help me find where I'm wasting my time",
  ];
}

export function getGreetingName(user?: User | null) {
  if (!user) return "there";
  const metadata = user.user_metadata as Record<string, unknown> | undefined;
  const fullName = metadata?.full_name;
@@ -36,3 +16,30 @@ export function getGreetingName(user?: User | null) {
  }
  return "there";
}

export function buildCopilotChatUrl(prompt: string): string {
  const trimmed = prompt.trim();
  if (!trimmed) return "/copilot/chat";
  const encoded = encodeURIComponent(trimmed);
  return `/copilot/chat?prompt=${encoded}`;
}

export function getQuickActions(): string[] {
  return [
    "I don't know where to start, just ask me stuff",
    "I do the same thing every week and it's killing me",
    "Help me find where I'm wasting my time",
  ];
}

export function getInputPlaceholder(width?: number) {
  if (!width) return "What's your role and what eats up most of your day?";

  if (width < 500) {
    return "I'm a chef and I hate...";
  }
  if (width <= 1080) {
    return "What's your role and what eats up most of your day?";
  }
  return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}
@@ -1,128 +0,0 @@
import type { UIMessage, UIDataTypes, UITools } from "ai";

interface SessionChatMessage {
  role: string;
  content: string | null;
  tool_call_id: string | null;
  tool_calls: unknown[] | null;
}

function coerceSessionChatMessages(
  rawMessages: unknown[],
): SessionChatMessage[] {
  return rawMessages
    .map((m) => {
      if (!m || typeof m !== "object") return null;
      const msg = m as Record<string, unknown>;

      const role = typeof msg.role === "string" ? msg.role : null;
      if (!role) return null;

      return {
        role,
        content:
          typeof msg.content === "string"
            ? msg.content
            : msg.content == null
              ? null
              : String(msg.content),
        tool_call_id:
          typeof msg.tool_call_id === "string"
            ? msg.tool_call_id
            : msg.tool_call_id == null
              ? null
              : String(msg.tool_call_id),
        tool_calls: Array.isArray(msg.tool_calls) ? msg.tool_calls : null,
      };
    })
    .filter((m): m is SessionChatMessage => m !== null);
}

function safeJsonParse(value: string): unknown {
  try {
    return JSON.parse(value) as unknown;
  } catch {
    return value;
  }
}

function toToolInput(rawArguments: unknown): unknown {
  if (typeof rawArguments === "string") {
    const trimmed = rawArguments.trim();
    return trimmed ? safeJsonParse(trimmed) : {};
  }
  if (rawArguments && typeof rawArguments === "object") return rawArguments;
  return {};
}

export function convertChatSessionMessagesToUiMessages(
  sessionId: string,
  rawMessages: unknown[],
): UIMessage<unknown, UIDataTypes, UITools>[] {
  const messages = coerceSessionChatMessages(rawMessages);
  const toolOutputsByCallId = new Map<string, unknown>();

  for (const msg of messages) {
    if (msg.role !== "tool") continue;
    if (!msg.tool_call_id) continue;
    if (msg.content == null) continue;
    toolOutputsByCallId.set(msg.tool_call_id, msg.content);
  }

  const uiMessages: UIMessage<unknown, UIDataTypes, UITools>[] = [];

  messages.forEach((msg, index) => {
    if (msg.role === "tool") return;
    if (msg.role !== "user" && msg.role !== "assistant") return;

    const parts: UIMessage<unknown, UIDataTypes, UITools>["parts"] = [];

    if (typeof msg.content === "string" && msg.content.trim()) {
      parts.push({ type: "text", text: msg.content, state: "done" });
    }

    if (msg.role === "assistant" && Array.isArray(msg.tool_calls)) {
      for (const rawToolCall of msg.tool_calls) {
        if (!rawToolCall || typeof rawToolCall !== "object") continue;
        const toolCall = rawToolCall as {
          id?: unknown;
          function?: { name?: unknown; arguments?: unknown };
        };

        const toolCallId = String(toolCall.id ?? "").trim();
        const toolName = String(toolCall.function?.name ?? "").trim();
        if (!toolCallId || !toolName) continue;

        const input = toToolInput(toolCall.function?.arguments);
        const output = toolOutputsByCallId.get(toolCallId);

        if (output !== undefined) {
          parts.push({
            type: `tool-${toolName}`,
            toolCallId,
            state: "output-available",
            input,
            output: typeof output === "string" ? safeJsonParse(output) : output,
          });
        } else {
          parts.push({
            type: `tool-${toolName}`,
            toolCallId,
            state: "input-available",
            input,
          });
        }
      }
    }

    if (parts.length === 0) return;

    uiMessages.push({
      id: `${sessionId}-${index}`,
      role: msg.role,
      parts,
    });
  });

  return uiMessages;
}
@@ -0,0 +1,13 @@
"use client";
import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage";
import { Flag } from "@/services/feature-flags/use-get-flag";
import { type ReactNode } from "react";
import { CopilotShell } from "./components/CopilotShell/CopilotShell";

export default function CopilotLayout({ children }: { children: ReactNode }) {
  return (
    <FeatureFlagPage flag={Flag.CHAT} whenDisabled="/library">
      <CopilotShell>{children}</CopilotShell>
    </FeatureFlagPage>
  );
}
@@ -1,13 +1,149 @@
"use client";

import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage";
import { Button } from "@/components/atoms/Button/Button";
import { Flag } from "@/services/feature-flags/use-get-flag";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { CopilotPage } from "./CopilotPage";
import { Text } from "@/components/atoms/Text/Text";
import { Chat } from "@/components/contextual/Chat/Chat";
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useEffect, useState } from "react";
import { useCopilotStore } from "./copilot-page-store";
import { getInputPlaceholder } from "./helpers";
import { useCopilotPage } from "./useCopilotPage";

export default function CopilotPage() {
  const { state, handlers } = useCopilotPage();
  const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
  const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
  const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);

  const [inputPlaceholder, setInputPlaceholder] = useState(
    getInputPlaceholder(),
  );

  useEffect(() => {
    const handleResize = () => {
      setInputPlaceholder(getInputPlaceholder(window.innerWidth));
    };

    handleResize();

    window.addEventListener("resize", handleResize);
    return () => window.removeEventListener("resize", handleResize);
  }, []);

  const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
    state;

  const {
    handleQuickAction,
    startChatWithPrompt,
    handleSessionNotFound,
    handleStreamingChange,
  } = handlers;

  if (hasSession) {
    return (
      <div className="flex h-full flex-col">
        <Chat
          className="flex-1"
          initialPrompt={initialPrompt}
          onSessionNotFound={handleSessionNotFound}
          onStreamingChange={handleStreamingChange}
        />
        <Dialog
          title="Interrupt current chat?"
          styling={{ maxWidth: 300, width: "100%" }}
          controlled={{
            isOpen: isInterruptModalOpen,
            set: (open) => {
              if (!open) cancelInterrupt();
            },
          }}
          onClose={cancelInterrupt}
        >
          <Dialog.Content>
            <div className="flex flex-col gap-4">
              <Text variant="body">
                The current chat response will be interrupted. Are you sure you
                want to continue?
              </Text>
              <Dialog.Footer>
                <Button
                  type="button"
                  variant="outline"
                  onClick={cancelInterrupt}
                >
                  Cancel
                </Button>
                <Button
                  type="button"
                  variant="primary"
                  onClick={confirmInterrupt}
                >
                  Continue
                </Button>
              </Dialog.Footer>
            </div>
          </Dialog.Content>
        </Dialog>
      </div>
    );
  }

export default function Page() {
  return (
    <FeatureFlagPage flag={Flag.CHAT} whenDisabled="/library">
    <div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-3 py-5 md:px-6 md:py-10">
      <CopilotPage />
      <div className="w-full text-center">
    </FeatureFlagPage>
        {isLoading ? (
          <div className="mx-auto max-w-2xl">
            <Skeleton className="mx-auto mb-3 h-8 w-64" />
            <Skeleton className="mx-auto mb-8 h-6 w-80" />
            <div className="mb-8">
              <Skeleton className="mx-auto h-14 w-full rounded-lg" />
            </div>
            <div className="flex flex-wrap items-center justify-center gap-3">
              {Array.from({ length: 4 }).map((_, i) => (
                <Skeleton key={i} className="h-9 w-48 rounded-md" />
              ))}
            </div>
          </div>
        ) : (
          <>
            <div className="mx-auto max-w-3xl">
              <Text
                variant="h3"
                className="mb-1 !text-[1.375rem] text-zinc-700"
              >
                Hey, <span className="text-violet-600">{greetingName}</span>
              </Text>
              <Text variant="h3" className="mb-8 !font-normal">
                Tell me about your work — I'll find what to automate.
              </Text>

              <div className="mb-6">
                <ChatInput
                  onSend={startChatWithPrompt}
                  placeholder={inputPlaceholder}
                />
              </div>
            </div>
            <div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
              {quickActions.map((action) => (
                <Button
                  key={action}
                  type="button"
                  variant="outline"
                  size="small"
                  onClick={() => handleQuickAction(action)}
                  className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
                >
                  {action}
                </Button>
              ))}
            </div>
          </>
        )}
      </div>
    </div>
  );
}
File diff suppressed because it is too large
@@ -1,235 +0,0 @@
"use client";

import { WarningDiamondIcon } from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
import { ProgressBar } from "../../components/ProgressBar/ProgressBar";
import {
  ContentCardDescription,
  ContentCodeBlock,
  ContentGrid,
  ContentHint,
  ContentLink,
  ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import { useAsymptoticProgress } from "../../hooks/useAsymptoticProgress";
import {
  ClarificationQuestionsCard,
  ClarifyingQuestion,
} from "./components/ClarificationQuestionsCard";
import {
  AccordionIcon,
  formatMaybeJson,
  getAnimationText,
  getCreateAgentToolOutput,
  isAgentPreviewOutput,
  isAgentSavedOutput,
  isClarificationNeededOutput,
  isErrorOutput,
  isOperationInProgressOutput,
  isOperationPendingOutput,
  isOperationStartedOutput,
  ToolIcon,
  truncateText,
  type CreateAgentToolOutput,
} from "./helpers";

export interface CreateAgentToolPart {
  type: string;
  toolCallId: string;
  state: ToolUIPart["state"];
  input?: unknown;
  output?: unknown;
}

interface Props {
  part: CreateAgentToolPart;
}

function getAccordionMeta(output: CreateAgentToolOutput) {
  const icon = <AccordionIcon />;

  if (isAgentSavedOutput(output)) {
    return { icon, title: output.agent_name };
  }
  if (isAgentPreviewOutput(output)) {
    return {
      icon,
      title: output.agent_name,
      description: `${output.node_count} block${output.node_count === 1 ? "" : "s"}`,
    };
  }
  if (isClarificationNeededOutput(output)) {
    const questions = output.questions ?? [];
    return {
      icon,
      title: "Needs clarification",
      description: `${questions.length} question${questions.length === 1 ? "" : "s"}`,
      expanded: true,
    };
  }
  if (
    isOperationStartedOutput(output) ||
    isOperationPendingOutput(output) ||
    isOperationInProgressOutput(output)
  ) {
    return {
      icon: <OrbitLoader size={32} />,
      title: "Creating agent, this may take a few minutes. Sit back and relax.",
    };
  }
  return {
    icon: (
      <WarningDiamondIcon size={32} weight="light" className="text-red-500" />
    ),
    title: "Error",
    titleClassName: "text-red-500",
  };
}

export function CreateAgentTool({ part }: Props) {
  const text = getAnimationText(part);
  const { onSend } = useCopilotChatActions();

  const isStreaming =
    part.state === "input-streaming" || part.state === "input-available";

  const output = getCreateAgentToolOutput(part);

  const isError =
    part.state === "output-error" || (!!output && isErrorOutput(output));

  const isOperating =
    !!output &&
    (isOperationStartedOutput(output) ||
      isOperationPendingOutput(output) ||
      isOperationInProgressOutput(output));

  const progress = useAsymptoticProgress(isOperating);

  const hasExpandableContent =
    part.state === "output-available" &&
    !!output &&
    (isOperationStartedOutput(output) ||
      isOperationPendingOutput(output) ||
      isOperationInProgressOutput(output) ||
      isAgentPreviewOutput(output) ||
      isAgentSavedOutput(output) ||
      isClarificationNeededOutput(output) ||
      isErrorOutput(output));

  function handleClarificationAnswers(answers: Record<string, string>) {
    const questions =
      output && isClarificationNeededOutput(output)
        ? (output.questions ?? [])
        : [];

    const contextMessage = questions
      .map((q) => {
        const answer = answers[q.keyword] || "";
        return `> ${q.question}\n\n${answer}`;
      })
      .join("\n\n");

    onSend(
      `**Here are my answers:**\n\n${contextMessage}\n\nPlease proceed with creating the agent.`,
    );
  }

  return (
    <div className="py-2">
      <div className="flex items-center gap-2 text-sm text-muted-foreground">
        <ToolIcon isStreaming={isStreaming} isError={isError} />
        <MorphingTextAnimation
          text={text}
          className={isError ? "text-red-500" : undefined}
        />
      </div>

      {hasExpandableContent && output && (
        <ToolAccordion {...getAccordionMeta(output)}>
          {isOperating && (
            <ContentGrid>
              <ProgressBar value={progress} className="max-w-[280px]" />
              <ContentHint>
                This could take a few minutes, grab a coffee ☕
              </ContentHint>
            </ContentGrid>
          )}

          {isAgentSavedOutput(output) && (
            <ContentGrid>
              <ContentMessage>{output.message}</ContentMessage>
              <div className="flex flex-wrap gap-2">
                <ContentLink href={output.library_agent_link}>
                  Open in library
                </ContentLink>
                <ContentLink href={output.agent_page_link}>
                  Open in builder
                </ContentLink>
              </div>
              <ContentCodeBlock>
                {truncateText(
                  formatMaybeJson({ agent_id: output.agent_id }),
                  800,
                )}
              </ContentCodeBlock>
            </ContentGrid>
          )}

          {isAgentPreviewOutput(output) && (
            <ContentGrid>
              <ContentMessage>{output.message}</ContentMessage>
              {output.description?.trim() && (
                <ContentCardDescription>
                  {output.description}
                </ContentCardDescription>
              )}
              <ContentCodeBlock>
                {truncateText(formatMaybeJson(output.agent_json), 1600)}
              </ContentCodeBlock>
            </ContentGrid>
          )}

          {isClarificationNeededOutput(output) && (
            <ClarificationQuestionsCard
              questions={(output.questions ?? []).map((q) => {
                const item: ClarifyingQuestion = {
                  question: q.question,
                  keyword: q.keyword,
                };
                const example =
                  typeof q.example === "string" && q.example.trim()
                    ? q.example.trim()
                    : null;
                if (example) item.example = example;
                return item;
              })}
              message={output.message}
              onSubmitAnswers={handleClarificationAnswers}
            />
          )}

          {isErrorOutput(output) && (
            <ContentGrid>
              <ContentMessage>{output.message}</ContentMessage>
              {output.error && (
                <ContentCodeBlock>
                  {formatMaybeJson(output.error)}
                </ContentCodeBlock>
              )}
              {output.details && (
                <ContentCodeBlock>
                  {formatMaybeJson(output.details)}
                </ContentCodeBlock>
              )}
            </ContentGrid>
          )}
        </ToolAccordion>
      )}
    </div>
  );
}
@@ -1,186 +0,0 @@
import type { AgentPreviewResponse } from "@/app/api/__generated__/models/agentPreviewResponse";
import type { AgentSavedResponse } from "@/app/api/__generated__/models/agentSavedResponse";
import type { ClarificationNeededResponse } from "@/app/api/__generated__/models/clarificationNeededResponse";
import type { ErrorResponse } from "@/app/api/__generated__/models/errorResponse";
import type { OperationInProgressResponse } from "@/app/api/__generated__/models/operationInProgressResponse";
import type { OperationPendingResponse } from "@/app/api/__generated__/models/operationPendingResponse";
import type { OperationStartedResponse } from "@/app/api/__generated__/models/operationStartedResponse";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import {
  PlusCircleIcon,
  PlusIcon,
  WarningDiamondIcon,
} from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";

export type CreateAgentToolOutput =
  | OperationStartedResponse
  | OperationPendingResponse
  | OperationInProgressResponse
  | AgentPreviewResponse
  | AgentSavedResponse
  | ClarificationNeededResponse
  | ErrorResponse;

function parseOutput(output: unknown): CreateAgentToolOutput | null {
  if (!output) return null;
  if (typeof output === "string") {
    const trimmed = output.trim();
    if (!trimmed) return null;
    try {
      return parseOutput(JSON.parse(trimmed) as unknown);
    } catch {
      return null;
    }
  }
  if (typeof output === "object") {
    const type = (output as { type?: unknown }).type;
    if (
      type === ResponseType.operation_started ||
      type === ResponseType.operation_pending ||
      type === ResponseType.operation_in_progress ||
      type === ResponseType.agent_preview ||
      type === ResponseType.agent_saved ||
      type === ResponseType.clarification_needed ||
      type === ResponseType.error
    ) {
      return output as CreateAgentToolOutput;
    }
    if ("operation_id" in output && "tool_name" in output)
      return output as OperationStartedResponse | OperationPendingResponse;
    if ("tool_call_id" in output) return output as OperationInProgressResponse;
    if ("agent_json" in output && "agent_name" in output)
      return output as AgentPreviewResponse;
    if ("agent_id" in output && "library_agent_id" in output)
      return output as AgentSavedResponse;
    if ("questions" in output) return output as ClarificationNeededResponse;
    if ("error" in output || "details" in output)
      return output as ErrorResponse;
  }
  return null;
}

export function getCreateAgentToolOutput(
  part: unknown,
): CreateAgentToolOutput | null {
  if (!part || typeof part !== "object") return null;
  return parseOutput((part as { output?: unknown }).output);
}

export function isOperationStartedOutput(
  output: CreateAgentToolOutput,
): output is OperationStartedResponse {
  return (
    output.type === ResponseType.operation_started ||
    ("operation_id" in output && "tool_name" in output)
  );
}

export function isOperationPendingOutput(
  output: CreateAgentToolOutput,
): output is OperationPendingResponse {
  return output.type === ResponseType.operation_pending;
}

export function isOperationInProgressOutput(
  output: CreateAgentToolOutput,
): output is OperationInProgressResponse {
  return (
    output.type === ResponseType.operation_in_progress ||
    "tool_call_id" in output
  );
}

export function isAgentPreviewOutput(
  output: CreateAgentToolOutput,
): output is AgentPreviewResponse {
  return output.type === ResponseType.agent_preview || "agent_json" in output;
}

export function isAgentSavedOutput(
  output: CreateAgentToolOutput,
): output is AgentSavedResponse {
  return (
    output.type === ResponseType.agent_saved || "agent_page_link" in output
  );
}

export function isClarificationNeededOutput(
  output: CreateAgentToolOutput,
): output is ClarificationNeededResponse {
  return (
    output.type === ResponseType.clarification_needed || "questions" in output
  );
}

export function isErrorOutput(
  output: CreateAgentToolOutput,
): output is ErrorResponse {
  return output.type === ResponseType.error || "error" in output;
}

export function getAnimationText(part: {
  state: ToolUIPart["state"];
  input?: unknown;
  output?: unknown;
}): string {
  switch (part.state) {
    case "input-streaming":
    case "input-available":
      return "Creating a new agent";
    case "output-available": {
      const output = parseOutput(part.output);
      if (!output) return "Creating a new agent";
      if (isOperationStartedOutput(output)) return "Agent creation started";
      if (isOperationPendingOutput(output)) return "Agent creation in progress";
      if (isOperationInProgressOutput(output))
        return "Agent creation already in progress";
      if (isAgentSavedOutput(output)) return `Saved "${output.agent_name}"`;
      if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
      if (isClarificationNeededOutput(output)) return "Needs clarification";
      return "Error creating agent";
    }
    case "output-error":
      return "Error creating agent";
    default:
      return "Creating a new agent";
  }
}

export function ToolIcon({
  isStreaming,
  isError,
}: {
  isStreaming?: boolean;
  isError?: boolean;
}) {
  if (isError) {
    return (
      <WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
    );
  }
  if (isStreaming) {
    return <OrbitLoader size={24} />;
  }
  return <PlusIcon size={14} weight="regular" className="text-neutral-400" />;
}

export function AccordionIcon() {
  return <PlusCircleIcon size={32} weight="light" />;
}

export function formatMaybeJson(value: unknown): string {
  if (typeof value === "string") return value;
  try {
    return JSON.stringify(value, null, 2);
  } catch {
    return String(value);
  }
}

export function truncateText(text: string, maxChars: number): string {
  const trimmed = text.trim();
  if (trimmed.length <= maxChars) return trimmed;
  return `${trimmed.slice(0, maxChars).trimEnd()}…`;
}
Some files were not shown because too many files have changed in this diff