mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
51 Commits
feat/build
...
feature/so
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64689f7319 | ||
|
|
095821b252 | ||
|
|
5cfdce995d | ||
|
|
c24ce05970 | ||
|
|
3f2cfa93ef | ||
|
|
fb035ffd35 | ||
|
|
68eb8f35d7 | ||
|
|
23b90151fe | ||
|
|
1ace125395 | ||
|
|
59b5e64c29 | ||
|
|
22666670cc | ||
|
|
9adaeda70c | ||
|
|
9a5a852be2 | ||
|
|
cce1c60ab7 | ||
|
|
34690c463d | ||
|
|
5feab3dcfa | ||
|
|
21631a565b | ||
|
|
2891e5c48e | ||
|
|
05d1269758 | ||
|
|
943a1df815 | ||
|
|
593001e0c8 | ||
|
|
e1db8234a3 | ||
|
|
282173be9d | ||
|
|
5d9a169e04 | ||
|
|
6fd1050457 | ||
|
|
02708bcd00 | ||
|
|
156d61fe5c | ||
|
|
5a29de0e0e | ||
|
|
e657472162 | ||
|
|
4d00e0f179 | ||
|
|
1d7282b5f3 | ||
|
|
e3591fcaa3 | ||
|
|
876dc32e17 | ||
|
|
616e29f5e4 | ||
|
|
280a98ad38 | ||
|
|
c7f2a7dd03 | ||
|
|
6d0e2063ec | ||
|
|
8b577ae194 | ||
|
|
d8f5f783ae | ||
|
|
82d22f3680 | ||
|
|
50622333d1 | ||
|
|
27af5782a9 | ||
|
|
522f932e67 | ||
|
|
a6124b06d5 | ||
|
|
ae660ea04f | ||
|
|
2479f3a1c4 | ||
|
|
8153306384 | ||
|
|
9c3d100a22 | ||
|
|
fc3bf6c154 | ||
|
|
e32d258a7e | ||
|
|
3e86544bfe |
@@ -2,7 +2,7 @@
|
||||
name: pr-address
|
||||
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
|
||||
user-invocable: true
|
||||
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
@@ -40,8 +40,8 @@ Address comments **one at a time**: fix → commit → push → inline reply →
|
||||
|
||||
| Comment type | How to reply |
|
||||
|---|---|
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
|
||||
## Format and commit
|
||||
|
||||
@@ -69,11 +69,78 @@ For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
```text
|
||||
address comments → format → commit → push
|
||||
→ re-check comments → fix new ones → push
|
||||
→ wait for CI → re-check comments after CI settles
|
||||
→ wait for CI (while addressing new comments) → fix failures → push
|
||||
→ re-check comments after CI settles
|
||||
→ repeat until: all comments addressed AND CI green AND no new comments arriving
|
||||
```
|
||||
|
||||
While CI runs, stay productive: run local tests, address remaining comments.
|
||||
### Polling for CI + new comments
|
||||
|
||||
**The loop ends when:** CI fully green + all comments addressed + no new comments since CI settled.
|
||||
After pushing, poll for **both** CI status and new comments in a single loop. Do not use `gh pr checks --watch` — it blocks the tool and prevents reacting to new comments while CI is running.
|
||||
|
||||
> **Note:** `gh pr checks --watch --fail-fast` is tempting but it blocks the entire Bash tool call, meaning the agent cannot check for or address new comments until CI fully completes. Always poll manually instead.
|
||||
|
||||
**Polling loop — repeat every 30 seconds:**
|
||||
|
||||
1. Check CI status:
|
||||
```bash
|
||||
gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,name,link
|
||||
```
|
||||
Parse the results: if every check has `bucket` of `"pass"` or `"skipping"`, CI is green. If any has `"fail"`, CI has failed. Otherwise CI is still pending.
|
||||
|
||||
2. Check for merge conflicts:
|
||||
```bash
|
||||
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json mergeable --jq '.mergeable'
|
||||
```
|
||||
If the result is `"CONFLICTING"`, the PR has a merge conflict — see "Resolving merge conflicts" below. If `"UNKNOWN"`, GitHub is still computing mergeability — wait and re-check next poll.
|
||||
|
||||
3. Check for new comments (all three sources):
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
|
||||
```
|
||||
Compare against previously seen comments to detect new ones.
|
||||
|
||||
4. **React in this precedence order (first match wins):**
|
||||
|
||||
| What happened | Action |
|
||||
|---|---|
|
||||
| Merge conflict detected | See "Resolving merge conflicts" below. |
|
||||
| Mergeability is `UNKNOWN` | GitHub is still computing mergeability. Sleep 30 seconds, then restart polling from the top. |
|
||||
| New comments detected | Address them (fix → commit → push → reply). After pushing, re-fetch all comments to update your baseline, then restart this polling loop from the top (new commits invalidate CI status). |
|
||||
| CI failed (bucket == "fail") | Get failed check links: `gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,link --jq '.[] \| select(.bucket == "fail") \| .link'`. Extract run ID from link (format: `.../actions/runs/<run-id>/job/...`), read logs with `gh run view <run-id> --repo Significant-Gravitas/AutoGPT --log-failed`. Fix → commit → push → restart polling. |
|
||||
| CI green + no new comments | **Do not exit immediately.** Bots (coderabbitai, sentry) often post reviews shortly after CI settles. Continue polling for **2 more cycles (60s)** after CI goes green. Only exit after 2 consecutive green+quiet polls. |
|
||||
| CI pending + no new comments | Sleep 30 seconds, then poll again. |
|
||||
|
||||
**The loop ends when:** CI fully green + all comments addressed + **2 consecutive polls with no new comments after CI settled.**
|
||||
|
||||
### Resolving merge conflicts
|
||||
|
||||
1. Identify the PR's target branch and remote:
|
||||
```bash
|
||||
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json baseRefName --jq '.baseRefName'
|
||||
git remote -v # find the remote pointing to Significant-Gravitas/AutoGPT (typically 'upstream' in forks, 'origin' for direct contributors)
|
||||
```
|
||||
|
||||
2. Pull the latest base branch with a 3-way merge:
|
||||
```bash
|
||||
git pull {base-remote} {base-branch} --no-rebase
|
||||
```
|
||||
|
||||
3. Resolve conflicting files, then verify no conflict markers remain:
|
||||
```bash
|
||||
if grep -R -n -E '^(<<<<<<<|=======|>>>>>>>)' <conflicted-files>; then
|
||||
echo "Unresolved conflict markers found — resolve before proceeding."
|
||||
exit 1
|
||||
fi
|
||||
```
|
||||
|
||||
4. Stage and push:
|
||||
```bash
|
||||
git add <conflicted-files>
|
||||
git commit -m "Resolve merge conflicts with {base-branch}"
|
||||
git push
|
||||
```
|
||||
|
||||
5. Restart the polling loop from the top — new commits reset CI status.
|
||||
|
||||
114
.github/workflows/platform-backend-ci.yml
vendored
114
.github/workflows/platform-backend-ci.yml
vendored
@@ -27,10 +27,91 @@ defaults:
|
||||
working-directory: autogpt_platform/backend
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 10
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py3.12-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run Linters
|
||||
run: poetry run lint --skip-pyright
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
type-check:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- name: Run Pyright
|
||||
run: poetry run pyright --pythonversion ${{ matrix.python-version }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 15
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -98,9 +179,9 @@ jobs:
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
@@ -158,22 +239,22 @@ jobs:
|
||||
echo "Waiting for ClamAV daemon to start..."
|
||||
max_attempts=60
|
||||
attempt=0
|
||||
|
||||
|
||||
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
|
||||
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
|
||||
sleep 5
|
||||
attempt=$((attempt+1))
|
||||
done
|
||||
|
||||
|
||||
if [ $attempt -eq $max_attempts ]; then
|
||||
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
|
||||
echo "Checking ClamAV service logs..."
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "ClamAV is ready!"
|
||||
|
||||
|
||||
# Verify ClamAV is responsive
|
||||
echo "Testing ClamAV connection..."
|
||||
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
|
||||
@@ -188,18 +269,13 @@ jobs:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
run: poetry run lint
|
||||
|
||||
- name: Run pytest with coverage
|
||||
- name: Run pytest
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
else
|
||||
poetry run pytest -s -vv
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
@@ -211,6 +287,12 @@ jobs:
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
@@ -224,9 +306,3 @@ jobs:
|
||||
# the backend service, docker composes, and examples
|
||||
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
|
||||
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
4
.github/workflows/platform-fullstack-ci.yml
vendored
4
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -294,7 +294,7 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
path: autogpt_platform/frontend/playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
@@ -303,7 +303,7 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
path: autogpt_platform/frontend/test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
@@ -75,6 +75,7 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# =============================== DB MIGRATOR =============================== #
|
||||
@@ -82,7 +82,7 @@ RUN pip3 install prisma>=0.15.0 --break-system-packages
|
||||
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
|
||||
COPY autogpt_platform/backend/migrations ./migrations
|
||||
|
||||
# ============================== BACKEND SERVER ============================== #
|
||||
@@ -121,19 +121,37 @@ RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
# Install agent-browser (Copilot browser tool) + Chromium.
|
||||
# On amd64: install runtime libs + run `agent-browser install` to download
|
||||
# Chrome for Testing (pinned version, tested with Playwright).
|
||||
# On arm64: install system chromium package — Chrome for Testing has no ARM64
|
||||
# binary. AGENT_BROWSER_EXECUTABLE_PATH is set at runtime by the entrypoint
|
||||
# script (below) to redirect agent-browser to the system binary.
|
||||
ARG TARGETARCH
|
||||
RUN apt-get update \
|
||||
&& if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
apt-get install -y --no-install-recommends chromium fonts-liberation; \
|
||||
else \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1; \
|
||||
fi \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& agent-browser install \
|
||||
&& ([ "$TARGETARCH" = "arm64" ] || agent-browser install) \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
# On arm64 the system chromium is at /usr/bin/chromium; set
|
||||
# AGENT_BROWSER_EXECUTABLE_PATH so agent-browser's daemon uses it instead of
|
||||
# Chrome for Testing (which has no ARM64 binary). On amd64 the variable is left
|
||||
# unset so agent-browser uses the Chrome for Testing binary it downloaded above.
|
||||
RUN printf '#!/bin/sh\n[ -x /usr/bin/chromium ] && export AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium\nexec "$@"\n' \
|
||||
> /usr/local/bin/entrypoint.sh \
|
||||
&& chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
@@ -155,4 +173,5 @@ RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
|
||||
ENV PORT=8000
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["rest"]
|
||||
|
||||
@@ -4,14 +4,12 @@ from difflib import SequenceMatcher
|
||||
from typing import Any, Sequence, get_args, get_origin
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from prisma.models import mv_suggested_blocks
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks._base import (
|
||||
AnyBlockSchema,
|
||||
@@ -24,6 +22,7 @@ from backend.blocks.llm import LlmModel
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
from .model import (
|
||||
BlockCategoryResponse,
|
||||
@@ -271,7 +270,7 @@ async def _build_cached_search_results(
|
||||
|
||||
# Use hybrid search when query is present, otherwise list all blocks
|
||||
if (include_blocks or include_integrations) and normalized_query:
|
||||
block_results, block_total, integration_total = await _hybrid_search_blocks(
|
||||
block_results, block_total, integration_total = await _text_search_blocks(
|
||||
query=search_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
@@ -383,117 +382,75 @@ def _collect_block_results(
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
async def _hybrid_search_blocks(
|
||||
async def _text_search_blocks(
|
||||
*,
|
||||
query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Search blocks using hybrid search with builder-specific filtering.
|
||||
Search blocks using in-memory text matching over the block registry.
|
||||
|
||||
Uses unified_hybrid_search for semantic + lexical search, then applies
|
||||
post-filtering for block/integration types and scoring adjustments.
|
||||
All blocks are already loaded in memory, so this is fast and reliable
|
||||
regardless of whether OpenAI embeddings are available.
|
||||
|
||||
Scoring:
|
||||
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
|
||||
- Base: text relevance via _score_primary_fields, plus BLOCK_SCORE_BOOST
|
||||
to prioritize blocks over marketplace agents in combined results
|
||||
- +30 for exact name match, +15 for prefix name match
|
||||
- +20 if the block has an LlmModel field and the query matches an LLM model name
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
include_blocks: Whether to include regular blocks
|
||||
include_integrations: Whether to include integration blocks
|
||||
|
||||
Returns:
|
||||
Tuple of (scored_items, block_count, integration_count)
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
|
||||
if not include_blocks and not include_integrations:
|
||||
return results, block_count, integration_count
|
||||
return results, 0, 0
|
||||
|
||||
normalized_query = query.strip().lower()
|
||||
|
||||
# Fetch more results to account for post-filtering
|
||||
search_results, _ = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=150,
|
||||
min_score=0.10,
|
||||
all_results, _, _ = _collect_block_results(
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
|
||||
# Load all blocks for getting BlockInfo
|
||||
all_blocks = load_all_blocks()
|
||||
|
||||
for result in search_results:
|
||||
block_id = result["content_id"]
|
||||
for item in all_results:
|
||||
block_info = item.item
|
||||
assert isinstance(block_info, BlockInfo)
|
||||
name = split_camelcase(block_info.name).lower()
|
||||
|
||||
# Skip excluded blocks
|
||||
if block_id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
# Build rich description including input field descriptions,
|
||||
# matching the searchable text that the embedding pipeline uses
|
||||
desc_parts = [block_info.description or ""]
|
||||
block_cls = all_blocks.get(block_info.id)
|
||||
if block_cls is not None:
|
||||
block: AnyBlockSchema = block_cls()
|
||||
desc_parts += [
|
||||
f"{f}: {info.description}"
|
||||
for f, info in block.input_schema.model_fields.items()
|
||||
if info.description
|
||||
]
|
||||
description = " ".join(desc_parts).lower()
|
||||
|
||||
metadata = result.get("metadata", {})
|
||||
hybrid_score = result.get("relevance", 0.0)
|
||||
|
||||
# Get the actual block class
|
||||
if block_id not in all_blocks:
|
||||
continue
|
||||
|
||||
block_cls = all_blocks[block_id]
|
||||
block: AnyBlockSchema = block_cls()
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Check block/integration filter using metadata
|
||||
is_integration = metadata.get("is_integration", False)
|
||||
|
||||
if is_integration and not include_integrations:
|
||||
continue
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
# Get block info
|
||||
block_info = block.get_info()
|
||||
|
||||
# Calculate final score: scale hybrid score and add builder-specific bonuses
|
||||
# Hybrid scores are 0-1, builder scores were 0-200+
|
||||
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
|
||||
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
|
||||
# Add LLM model match bonus
|
||||
has_llm_field = metadata.get("has_llm_model_field", False)
|
||||
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
|
||||
final_score += 20
|
||||
if block_cls is not None and _matches_llm_model(
|
||||
block_cls().input_schema, normalized_query
|
||||
):
|
||||
score += 20
|
||||
|
||||
# Add exact/prefix match bonus for deterministic tie-breaking
|
||||
name = block_info.name.lower()
|
||||
if name == normalized_query:
|
||||
final_score += 30
|
||||
elif name.startswith(normalized_query):
|
||||
final_score += 15
|
||||
|
||||
# Track counts
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
else:
|
||||
block_count += 1
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=final_score,
|
||||
sort_key=name,
|
||||
if score >= MIN_SCORE_FOR_FILTERED_RESULTS:
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=item.filter_type,
|
||||
score=score + BLOCK_SCORE_BOOST,
|
||||
sort_key=name,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
block_count = sum(1 for r in results if r.filter_type == "blocks")
|
||||
integration_count = sum(1 for r in results if r.filter_type == "integrations")
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ from backend.data.graph import GraphSettings
|
||||
from backend.data.includes import (
|
||||
AGENT_PRESET_INCLUDE,
|
||||
LIBRARY_FOLDER_INCLUDE,
|
||||
MAX_LIBRARY_AGENT_EXECUTIONS_FETCH,
|
||||
MAX_LIBRARY_AGENTS_LAST_EXECUTED_FETCH,
|
||||
library_agent_include,
|
||||
)
|
||||
from backend.data.model import CredentialsMetaInput, GraphInput
|
||||
@@ -59,7 +61,7 @@ async def list_library_agents(
|
||||
Args:
|
||||
user_id: The ID of the user whose LibraryAgents we want to retrieve.
|
||||
search_term: Optional string to filter agents by name/description.
|
||||
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
|
||||
sort_by: Sorting field (createdAt, updatedAt, lastExecuted).
|
||||
page: Current page (1-indexed).
|
||||
page_size: Number of items per page.
|
||||
folder_id: Filter by folder ID. If provided, only returns agents in this folder.
|
||||
@@ -124,16 +126,84 @@ async def list_library_agents(
|
||||
elif sort_by == library_model.LibraryAgentSort.UPDATED_AT:
|
||||
order_by = {"updatedAt": "desc"}
|
||||
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=include_executions
|
||||
),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
agent_count = await prisma.models.LibraryAgent.prisma().count(where=where_clause)
|
||||
# For LAST_EXECUTED sorting, we need to fetch execution data and sort in Python
|
||||
# since Prisma doesn't support sorting by nested relations
|
||||
if sort_by == library_model.LibraryAgentSort.LAST_EXECUTED:
|
||||
# TODO: This fetches up to MAX_LIBRARY_AGENTS_LAST_EXECUTED_FETCH agents
|
||||
# into memory for sorting. Prisma doesn't support sorting by nested relations,
|
||||
# so a dedicated lastExecutedAt column or raw SQL query would be needed for
|
||||
# database-level pagination. The ceiling prevents worst-case memory blowup.
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
take=MAX_LIBRARY_AGENTS_LAST_EXECUTED_FETCH,
|
||||
include=library_agent_include(
|
||||
user_id,
|
||||
include_nodes=False,
|
||||
include_executions=True,
|
||||
execution_limit=1,
|
||||
),
|
||||
)
|
||||
|
||||
def get_sort_key(
|
||||
agent: prisma.models.LibraryAgent,
|
||||
) -> tuple[int, float]:
|
||||
"""
|
||||
Returns a tuple for sorting: (has_no_executions, -timestamp).
|
||||
|
||||
Agents WITH executions come first (sorted by most recent execution),
|
||||
agents WITHOUT executions come last (sorted by creation date).
|
||||
"""
|
||||
graph = agent.AgentGraph
|
||||
if graph and graph.Executions and len(graph.Executions) > 0:
|
||||
execution = graph.Executions[0]
|
||||
timestamp = execution.updatedAt or execution.createdAt
|
||||
return (0, -timestamp.timestamp())
|
||||
return (1, -agent.createdAt.timestamp())
|
||||
|
||||
library_agents.sort(key=get_sort_key)
|
||||
|
||||
# Apply pagination after sorting
|
||||
agent_count = len(library_agents)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
page_agents = library_agents[start_idx:end_idx]
|
||||
|
||||
# Re-fetch the page agents with full execution data so that metrics
|
||||
# (execution_count, success_rate, avg_correctness_score, status) are
|
||||
# accurate. The sort-only fetch above used execution_limit=1 which
|
||||
# would make all metrics derived from a single execution.
|
||||
if include_executions and page_agents:
|
||||
page_agent_ids = [a.id for a in page_agents]
|
||||
full_exec_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where={"id": {"in": page_agent_ids}},
|
||||
include=library_agent_include(
|
||||
user_id,
|
||||
include_nodes=False,
|
||||
include_executions=True,
|
||||
execution_limit=MAX_LIBRARY_AGENT_EXECUTIONS_FETCH,
|
||||
),
|
||||
)
|
||||
# Restore sort order (find_many with `in` does not guarantee order)
|
||||
full_exec_map = {a.id: a for a in full_exec_agents}
|
||||
library_agents = [
|
||||
full_exec_map[a.id] for a in page_agents if a.id in full_exec_map
|
||||
]
|
||||
else:
|
||||
library_agents = page_agents
|
||||
else:
|
||||
# Standard sorting via database
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=include_executions
|
||||
),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
agent_count = await prisma.models.LibraryAgent.prisma().count(
|
||||
where=where_clause
|
||||
)
|
||||
|
||||
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
|
||||
|
||||
@@ -337,6 +407,20 @@ async def get_library_agent_by_graph_id(
|
||||
graph_id: str,
|
||||
graph_version: Optional[int] = None,
|
||||
) -> library_model.LibraryAgent | None:
|
||||
"""
|
||||
Retrieves a library agent by its graph ID for a given user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user who owns the library agent.
|
||||
graph_id: The ID of the agent graph to look up.
|
||||
graph_version: Optional specific version of the graph to retrieve.
|
||||
|
||||
Returns:
|
||||
The LibraryAgent if found, otherwise None.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's an error during retrieval.
|
||||
"""
|
||||
filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"agentGraphId": graph_id,
|
||||
"userId": user_id,
|
||||
@@ -724,6 +808,17 @@ async def update_library_agent(
|
||||
async def delete_library_agent(
|
||||
library_agent_id: str, user_id: str, soft_delete: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Deletes a library agent and cleans up associated schedules and webhooks.
|
||||
|
||||
Args:
|
||||
library_agent_id: The ID of the library agent to delete.
|
||||
user_id: The ID of the user who owns the library agent.
|
||||
soft_delete: If True, marks the agent as deleted; if False, permanently removes it.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the library agent is not found or doesn't belong to the user.
|
||||
"""
|
||||
# First get the agent to find the graph_id for cleanup
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().find_unique(
|
||||
where={"id": library_agent_id}, include={"AgentGraph": True}
|
||||
@@ -1827,6 +1922,20 @@ async def update_preset(
|
||||
async def set_preset_webhook(
|
||||
user_id: str, preset_id: str, webhook_id: str | None
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
"""
|
||||
Sets or removes a webhook connection for a preset.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user who owns the preset.
|
||||
preset_id: The ID of the preset to update.
|
||||
webhook_id: The ID of the webhook to connect, or None to disconnect.
|
||||
|
||||
Returns:
|
||||
The updated LibraryAgentPreset.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the preset is not found or doesn't belong to the user.
|
||||
"""
|
||||
current = await prisma.models.AgentPreset.prisma().find_unique(
|
||||
where={"id": preset_id},
|
||||
include=AGENT_PRESET_INCLUDE,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
@@ -8,6 +8,7 @@ from backend.data.db import connect
|
||||
from backend.data.includes import library_agent_include
|
||||
|
||||
from . import db
|
||||
from . import model as library_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -224,3 +225,506 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_library_agents_sort_by_last_executed(mocker):
|
||||
"""
|
||||
Test LAST_EXECUTED sorting behavior:
|
||||
- Agents WITH executions come first, sorted by most recent execution (updatedAt)
|
||||
- Agents WITHOUT executions come last, sorted by creation date
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Agent 1: Has execution that finished 1 hour ago
|
||||
agent1_execution = prisma.models.AgentGraphExecution(
|
||||
id="exec1",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
userId="test-user",
|
||||
createdAt=now - timedelta(hours=2),
|
||||
updatedAt=now - timedelta(hours=1), # Finished 1 hour ago
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
agent1_graph = prisma.models.AgentGraph(
|
||||
id="agent1",
|
||||
version=1,
|
||||
name="Agent With Recent Execution",
|
||||
description="Has execution finished 1 hour ago",
|
||||
userId="test-user",
|
||||
isActive=True,
|
||||
createdAt=now - timedelta(days=5),
|
||||
Executions=[agent1_execution],
|
||||
)
|
||||
library_agent1 = prisma.models.LibraryAgent(
|
||||
id="lib1",
|
||||
userId="test-user",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
settings="{}", # type: ignore
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=now - timedelta(days=5),
|
||||
updatedAt=now - timedelta(days=5),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=agent1_graph,
|
||||
)
|
||||
|
||||
# Agent 2: Has execution that finished 3 hours ago
|
||||
agent2_execution = prisma.models.AgentGraphExecution(
|
||||
id="exec2",
|
||||
agentGraphId="agent2",
|
||||
agentGraphVersion=1,
|
||||
userId="test-user",
|
||||
createdAt=now - timedelta(hours=5),
|
||||
updatedAt=now - timedelta(hours=3), # Finished 3 hours ago
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
agent2_graph = prisma.models.AgentGraph(
|
||||
id="agent2",
|
||||
version=1,
|
||||
name="Agent With Older Execution",
|
||||
description="Has execution finished 3 hours ago",
|
||||
userId="test-user",
|
||||
isActive=True,
|
||||
createdAt=now - timedelta(days=3),
|
||||
Executions=[agent2_execution],
|
||||
)
|
||||
library_agent2 = prisma.models.LibraryAgent(
|
||||
id="lib2",
|
||||
userId="test-user",
|
||||
agentGraphId="agent2",
|
||||
agentGraphVersion=1,
|
||||
settings="{}", # type: ignore
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=now - timedelta(days=3),
|
||||
updatedAt=now - timedelta(days=3),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=agent2_graph,
|
||||
)
|
||||
|
||||
# Agent 3: No executions, created 1 day ago (should come after agents with executions)
|
||||
agent3_graph = prisma.models.AgentGraph(
|
||||
id="agent3",
|
||||
version=1,
|
||||
name="Agent Without Executions (Newer)",
|
||||
description="No executions, created 1 day ago",
|
||||
userId="test-user",
|
||||
isActive=True,
|
||||
createdAt=now - timedelta(days=1),
|
||||
Executions=[],
|
||||
)
|
||||
library_agent3 = prisma.models.LibraryAgent(
|
||||
id="lib3",
|
||||
userId="test-user",
|
||||
agentGraphId="agent3",
|
||||
agentGraphVersion=1,
|
||||
settings="{}", # type: ignore
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=now - timedelta(days=1),
|
||||
updatedAt=now - timedelta(days=1),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=agent3_graph,
|
||||
)
|
||||
|
||||
# Agent 4: No executions, created 2 days ago
|
||||
agent4_graph = prisma.models.AgentGraph(
|
||||
id="agent4",
|
||||
version=1,
|
||||
name="Agent Without Executions (Older)",
|
||||
description="No executions, created 2 days ago",
|
||||
userId="test-user",
|
||||
isActive=True,
|
||||
createdAt=now - timedelta(days=2),
|
||||
Executions=[],
|
||||
)
|
||||
library_agent4 = prisma.models.LibraryAgent(
|
||||
id="lib4",
|
||||
userId="test-user",
|
||||
agentGraphId="agent4",
|
||||
agentGraphVersion=1,
|
||||
settings="{}", # type: ignore
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=now - timedelta(days=2),
|
||||
updatedAt=now - timedelta(days=2),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=agent4_graph,
|
||||
)
|
||||
|
||||
# Return agents in random order to verify sorting works
|
||||
mock_library_agents = [
|
||||
library_agent3,
|
||||
library_agent1,
|
||||
library_agent4,
|
||||
library_agent2,
|
||||
]
|
||||
|
||||
# Mock prisma calls
|
||||
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
|
||||
mock_agent_graph.return_value.find_many = mocker.AsyncMock(return_value=[])
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
|
||||
# Call function with LAST_EXECUTED sort (without include_executions)
|
||||
result = await db.list_library_agents(
|
||||
"test-user",
|
||||
sort_by=library_model.LibraryAgentSort.LAST_EXECUTED,
|
||||
)
|
||||
|
||||
# Verify sorting order:
|
||||
# 1. Agent 1 (execution finished 1 hour ago) - most recent execution
|
||||
# 2. Agent 2 (execution finished 3 hours ago) - older execution
|
||||
# 3. Agent 3 (no executions, created 1 day ago) - newer creation
|
||||
# 4. Agent 4 (no executions, created 2 days ago) - older creation
|
||||
assert len(result.agents) == 4
|
||||
assert (
|
||||
result.agents[0].id == "lib1"
|
||||
), "Agent with most recent execution should be first"
|
||||
assert result.agents[1].id == "lib2", "Agent with older execution should be second"
|
||||
assert (
|
||||
result.agents[2].id == "lib3"
|
||||
), "Agent without executions (newer) should be third"
|
||||
assert (
|
||||
result.agents[3].id == "lib4"
|
||||
), "Agent without executions (older) should be last"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_library_agents_last_executed_metrics_accuracy(mocker):
|
||||
"""
|
||||
Test that when LAST_EXECUTED sort is used with include_executions=True,
|
||||
metrics (execution_count, success_rate) are computed from the full execution
|
||||
history, not from the single execution used for sort-order determination.
|
||||
|
||||
Bug: execution_limit=1 was used for both sorting AND metric calculation,
|
||||
causing execution_count to always be 0 or 1 and success_rate to be 0% or 100%.
|
||||
Fix: after sorting/pagination, re-fetch the page agents with full execution data.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Agent with 1 execution (used for sort-key fetch, execution_limit=1)
|
||||
sort_execution = prisma.models.AgentGraphExecution(
|
||||
id="exec-sort",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
userId="test-user",
|
||||
createdAt=now - timedelta(hours=2),
|
||||
updatedAt=now - timedelta(hours=1),
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
sort_graph = prisma.models.AgentGraph(
|
||||
id="agent1",
|
||||
version=1,
|
||||
name="Agent With Many Executions",
|
||||
description="Should show full execution count",
|
||||
userId="test-user",
|
||||
isActive=True,
|
||||
createdAt=now - timedelta(days=5),
|
||||
Executions=[sort_execution], # Only 1 for sort
|
||||
)
|
||||
sort_library_agent = prisma.models.LibraryAgent(
|
||||
id="lib1",
|
||||
userId="test-user",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
settings="{}", # type: ignore
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=now - timedelta(days=5),
|
||||
updatedAt=now - timedelta(days=5),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=sort_graph,
|
||||
)
|
||||
|
||||
# Agent with full execution history (used for metric calculation, full execution_limit)
|
||||
full_exec1 = prisma.models.AgentGraphExecution(
|
||||
id="exec1",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
userId="test-user",
|
||||
createdAt=now - timedelta(hours=2),
|
||||
updatedAt=now - timedelta(hours=1),
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
full_exec2 = prisma.models.AgentGraphExecution(
|
||||
id="exec2",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
userId="test-user",
|
||||
createdAt=now - timedelta(hours=4),
|
||||
updatedAt=now - timedelta(hours=3),
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.FAILED,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
full_exec3 = prisma.models.AgentGraphExecution(
|
||||
id="exec3",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
userId="test-user",
|
||||
createdAt=now - timedelta(hours=6),
|
||||
updatedAt=now - timedelta(hours=5),
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
full_graph = prisma.models.AgentGraph(
|
||||
id="agent1",
|
||||
version=1,
|
||||
name="Agent With Many Executions",
|
||||
description="Should show full execution count",
|
||||
userId="test-user",
|
||||
isActive=True,
|
||||
createdAt=now - timedelta(days=5),
|
||||
Executions=[full_exec1, full_exec2, full_exec3], # All 3
|
||||
)
|
||||
full_library_agent = prisma.models.LibraryAgent(
|
||||
id="lib1",
|
||||
userId="test-user",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
settings="{}", # type: ignore
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=now - timedelta(days=5),
|
||||
updatedAt=now - timedelta(days=5),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=full_graph,
|
||||
)
|
||||
|
||||
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
|
||||
mock_agent_graph.return_value.find_many = mocker.AsyncMock(return_value=[])
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
# First call: sort-key fetch (execution_limit=1) → returns sort_library_agent
|
||||
# Second call: full metric fetch → returns full_library_agent
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
side_effect=[
|
||||
[sort_library_agent],
|
||||
[full_library_agent],
|
||||
]
|
||||
)
|
||||
|
||||
result = await db.list_library_agents(
|
||||
"test-user",
|
||||
sort_by=library_model.LibraryAgentSort.LAST_EXECUTED,
|
||||
include_executions=True,
|
||||
)
|
||||
|
||||
assert len(result.agents) == 1
|
||||
agent = result.agents[0]
|
||||
assert agent.id == "lib1"
|
||||
# With the fix: metrics are computed from all 3 executions, not just 1
|
||||
assert agent.execution_count == 3, (
|
||||
"execution_count should reflect the full execution history, not the "
|
||||
"sort-key fetch which used execution_limit=1"
|
||||
)
|
||||
# 2 out of 3 executions are COMPLETED → 66.67%
|
||||
assert agent.success_rate is not None
|
||||
assert (
|
||||
abs(agent.success_rate - 200 / 3) < 0.01
|
||||
), "success_rate should be calculated from all executions"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_library_agents_last_executed_null_updated_at(mocker):
|
||||
"""
|
||||
Test that the LAST_EXECUTED sort gracefully handles executions where updatedAt
|
||||
is None — the sort key should fall back to createdAt instead.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
execution_no_updated = prisma.models.AgentGraphExecution(
|
||||
id="exec-no-updated",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
userId="test-user",
|
||||
createdAt=now - timedelta(hours=2),
|
||||
updatedAt=None,
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.RUNNING,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
graph1 = prisma.models.AgentGraph(
|
||||
id="agent1",
|
||||
version=1,
|
||||
name="Agent With Null UpdatedAt",
|
||||
description="",
|
||||
userId="test-user",
|
||||
isActive=True,
|
||||
createdAt=now - timedelta(days=1),
|
||||
Executions=[execution_no_updated],
|
||||
)
|
||||
library_agent1 = prisma.models.LibraryAgent(
|
||||
id="lib1",
|
||||
userId="test-user",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
settings="{}", # type: ignore
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=now - timedelta(days=1),
|
||||
updatedAt=now - timedelta(days=1),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=graph1,
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=[library_agent1]
|
||||
)
|
||||
|
||||
result = await db.list_library_agents(
|
||||
"test-user",
|
||||
sort_by=library_model.LibraryAgentSort.LAST_EXECUTED,
|
||||
)
|
||||
|
||||
assert len(result.agents) == 1
|
||||
assert result.agents[0].id == "lib1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_library_agents_last_executed_none_agent_graph(mocker):
|
||||
"""
|
||||
Test that the LAST_EXECUTED sort safely handles agents where AgentGraph is None.
|
||||
Such agents should fall to the bottom (treated as no executions).
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
agent_no_graph = prisma.models.LibraryAgent(
|
||||
id="lib-no-graph",
|
||||
userId="test-user",
|
||||
agentGraphId="agent-gone",
|
||||
agentGraphVersion=1,
|
||||
settings="{}", # type: ignore
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=now - timedelta(days=1),
|
||||
updatedAt=now - timedelta(days=1),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=None,
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=[agent_no_graph]
|
||||
)
|
||||
|
||||
result = await db.list_library_agents(
|
||||
"test-user",
|
||||
sort_by=library_model.LibraryAgentSort.LAST_EXECUTED,
|
||||
)
|
||||
|
||||
assert (
|
||||
len(result.agents) == 0
|
||||
), "Agent with no graph should be skipped (from_db will fail gracefully)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_library_agents_last_executed_pagination(mocker):
|
||||
"""
|
||||
Test that LAST_EXECUTED sort correctly applies in-memory pagination:
|
||||
page 1 returns first page_size agents, page 2 returns the next batch,
|
||||
and agent_count reflects the total across all pages.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
def make_agent(agent_id: str, lib_id: str, hours_ago: int):
|
||||
execution = prisma.models.AgentGraphExecution(
|
||||
id=f"exec-{agent_id}",
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=1,
|
||||
userId="test-user",
|
||||
createdAt=now - timedelta(hours=hours_ago + 1),
|
||||
updatedAt=now - timedelta(hours=hours_ago),
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
graph = prisma.models.AgentGraph(
|
||||
id=agent_id,
|
||||
version=1,
|
||||
name=f"Agent {agent_id}",
|
||||
description="",
|
||||
userId="test-user",
|
||||
isActive=True,
|
||||
createdAt=now - timedelta(days=3),
|
||||
Executions=[execution],
|
||||
)
|
||||
return prisma.models.LibraryAgent(
|
||||
id=lib_id,
|
||||
userId="test-user",
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=1,
|
||||
settings="{}", # type: ignore
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=now - timedelta(days=3),
|
||||
updatedAt=now - timedelta(days=3),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=graph,
|
||||
)
|
||||
|
||||
# 3 agents, ordered newest-first by execution time: lib1, lib2, lib3
|
||||
agents = [
|
||||
make_agent("a1", "lib1", hours_ago=1),
|
||||
make_agent("a2", "lib2", hours_ago=2),
|
||||
make_agent("a3", "lib3", hours_ago=3),
|
||||
]
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(return_value=agents)
|
||||
|
||||
result_page1 = await db.list_library_agents(
|
||||
"test-user",
|
||||
sort_by=library_model.LibraryAgentSort.LAST_EXECUTED,
|
||||
page=1,
|
||||
page_size=2,
|
||||
)
|
||||
result_page2 = await db.list_library_agents(
|
||||
"test-user",
|
||||
sort_by=library_model.LibraryAgentSort.LAST_EXECUTED,
|
||||
page=2,
|
||||
page_size=2,
|
||||
)
|
||||
|
||||
assert result_page1.pagination.total_items == 3
|
||||
assert result_page1.pagination.total_pages == 2
|
||||
assert len(result_page1.agents) == 2
|
||||
assert result_page1.agents[0].id == "lib1"
|
||||
assert result_page1.agents[1].id == "lib2"
|
||||
|
||||
assert len(result_page2.agents) == 1
|
||||
assert result_page2.agents[0].id == "lib3"
|
||||
|
||||
@@ -539,6 +539,7 @@ class LibraryAgentSort(str, Enum):
|
||||
|
||||
CREATED_AT = "createdAt"
|
||||
UPDATED_AT = "updatedAt"
|
||||
LAST_EXECUTED = "lastExecuted"
|
||||
|
||||
|
||||
class LibraryAgentUpdateRequest(pydantic.BaseModel):
|
||||
|
||||
@@ -9,7 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.graph import (
|
||||
GraphModel,
|
||||
GraphModelWithoutNodes,
|
||||
@@ -104,7 +104,8 @@ async def get_store_agents(
|
||||
# search_used_hybrid remains False, will use fallback path below
|
||||
|
||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||
if search_used_hybrid:
|
||||
# Fall through to direct DB search if hybrid returned nothing
|
||||
if search_used_hybrid and agents:
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
@@ -130,52 +131,20 @@ async def get_store_agents(
|
||||
)
|
||||
continue
|
||||
|
||||
if not search_used_hybrid:
|
||||
# Fallback path - use basic search or no search
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
# Add basic text search if search_query provided but hybrid failed
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == StoreAgentsSortOptions.RATING:
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.RUNS:
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.NAME:
|
||||
order_by.append({"agent_name": "asc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.UPDATED_AT:
|
||||
order_by.append({"updated_at": "desc"})
|
||||
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
if not search_used_hybrid or not agents:
|
||||
# Fallback path: direct DB query with optional tsvector search.
|
||||
# This mirrors the original pre-hybrid-search implementation.
|
||||
store_agents, total = await _fallback_store_agent_search(
|
||||
search_query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in db_agents:
|
||||
try:
|
||||
store_agents.append(store_model.StoreAgent.from_db(agent))
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing StoreAgent from db: {e}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(store_agents)} agents")
|
||||
return store_model.StoreAgentsResponse(
|
||||
agents=store_agents,
|
||||
@@ -195,6 +164,126 @@ async def get_store_agents(
|
||||
# await log_search_term(search_query=search_term)
|
||||
|
||||
|
||||
async def _fallback_store_agent_search(
|
||||
*,
|
||||
search_query: str | None,
|
||||
featured: bool,
|
||||
creators: list[str] | None,
|
||||
category: str | None,
|
||||
sorted_by: StoreAgentsSortOptions | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> tuple[list[store_model.StoreAgent], int]:
|
||||
"""Direct DB search fallback when hybrid search is unavailable or empty.
|
||||
|
||||
Uses ad-hoc to_tsvector/plainto_tsquery with ts_rank_cd for text search,
|
||||
matching the quality of the original pre-hybrid-search implementation.
|
||||
Falls back to simple listing when no search query is provided.
|
||||
"""
|
||||
if not search_query:
|
||||
# No search query — use Prisma for simple filtered listing
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
order_by = []
|
||||
if sorted_by == StoreAgentsSortOptions.RATING:
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.RUNS:
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.NAME:
|
||||
order_by.append({"agent_name": "asc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.UPDATED_AT:
|
||||
order_by.append({"updated_at": "desc"})
|
||||
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
return [store_model.StoreAgent.from_db(a) for a in db_agents], total
|
||||
|
||||
# Text search using ad-hoc tsvector on StoreAgent view fields
|
||||
params: list[Any] = [search_query]
|
||||
filters = ["sa.is_available = true"]
|
||||
param_idx = 2
|
||||
|
||||
if featured:
|
||||
filters.append("sa.featured = true")
|
||||
if creators:
|
||||
params.append(creators)
|
||||
filters.append(f"sa.creator_username = ANY(${param_idx})")
|
||||
param_idx += 1
|
||||
if category:
|
||||
params.append(category)
|
||||
filters.append(f"${param_idx} = ANY(sa.categories)")
|
||||
param_idx += 1
|
||||
|
||||
where_sql = " AND ".join(filters)
|
||||
|
||||
params.extend([page_size, (page - 1) * page_size])
|
||||
limit_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
offset_param = f"${param_idx}"
|
||||
|
||||
sql = f"""
|
||||
WITH ranked AS (
|
||||
SELECT sa.*,
|
||||
ts_rank_cd(
|
||||
to_tsvector('english',
|
||||
COALESCE(sa.agent_name, '') || ' ' ||
|
||||
COALESCE(sa.sub_heading, '') || ' ' ||
|
||||
COALESCE(sa.description, '')
|
||||
),
|
||||
plainto_tsquery('english', $1)
|
||||
) AS rank,
|
||||
COUNT(*) OVER () AS total_count
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
WHERE {where_sql}
|
||||
AND to_tsvector('english',
|
||||
COALESCE(sa.agent_name, '') || ' ' ||
|
||||
COALESCE(sa.sub_heading, '') || ' ' ||
|
||||
COALESCE(sa.description, '')
|
||||
) @@ plainto_tsquery('english', $1)
|
||||
)
|
||||
SELECT * FROM ranked
|
||||
ORDER BY rank DESC
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(sql, *params)
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
store_agents = []
|
||||
for row in results:
|
||||
try:
|
||||
store_agents.append(
|
||||
store_model.StoreAgent(
|
||||
slug=row["slug"],
|
||||
agent_name=row["agent_name"],
|
||||
agent_image=row["agent_image"][0] if row["agent_image"] else "",
|
||||
creator=row["creator_username"] or "Needs Profile",
|
||||
creator_avatar=row["creator_avatar"] or "",
|
||||
sub_heading=row["sub_heading"],
|
||||
description=row["description"],
|
||||
runs=row["runs"],
|
||||
rating=row["rating"],
|
||||
agent_graph_id=row.get("graph_id", ""),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing StoreAgent from fallback search: {e}")
|
||||
continue
|
||||
|
||||
return store_agents, total
|
||||
|
||||
|
||||
async def log_search_term(search_query: str):
|
||||
"""Log a search term to the database"""
|
||||
|
||||
@@ -1139,16 +1228,21 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=submission.name,
|
||||
description=submission.description,
|
||||
sub_heading=submission.subHeading,
|
||||
categories=submission.categories,
|
||||
tx=tx,
|
||||
)
|
||||
# Generate embedding for approved listing (best-effort)
|
||||
try:
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=submission.name,
|
||||
description=submission.description,
|
||||
sub_heading=submission.subHeading,
|
||||
categories=submission.categories,
|
||||
tx=tx,
|
||||
)
|
||||
except Exception as emb_err:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for listing "
|
||||
f"{store_listing_version_id}: {emb_err}"
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": submission.storeListingId},
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import urllib.parse
|
||||
|
||||
import autogpt_libs.auth
|
||||
@@ -259,21 +258,18 @@ async def get_graph_meta_by_store_listing_version_id(
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str,
|
||||
) -> fastapi.responses.FileResponse:
|
||||
) -> fastapi.responses.Response:
|
||||
"""Download agent graph file for a specific marketplace listing version"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
return fastapi.responses.Response(
|
||||
content=backend.util.json.dumps(graph_data),
|
||||
media_type="application/json",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{file_name}"',
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -24,7 +24,7 @@ from fastapi import (
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
@@ -55,7 +55,11 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.invited_user import (
|
||||
check_invite_eligibility,
|
||||
get_or_activate_user,
|
||||
is_internal_email,
|
||||
)
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -70,6 +74,7 @@ from backend.data.onboarding import (
|
||||
reset_user_onboarding,
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.user import (
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
@@ -129,6 +134,69 @@ v1_router = APIRouter()
|
||||
_tally_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
class CheckInviteRequest(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class CheckInviteResponse(BaseModel):
|
||||
allowed: bool
|
||||
|
||||
|
||||
_CHECK_INVITE_RATE_LIMIT = 10 # requests
|
||||
_CHECK_INVITE_RATE_WINDOW = 60 # seconds
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/check-invite",
|
||||
summary="Check if an email is allowed to sign up",
|
||||
tags=["auth"],
|
||||
)
|
||||
async def check_invite_route(
|
||||
http_request: Request,
|
||||
request: CheckInviteRequest,
|
||||
) -> CheckInviteResponse:
|
||||
"""Check if an email is allowed to sign up (no auth required).
|
||||
|
||||
Called by the frontend before creating a Supabase auth user to prevent
|
||||
orphaned accounts when the invite gate is enabled.
|
||||
"""
|
||||
client_ip = (
|
||||
http_request.headers.get("x-forwarded-for", "").split(",")[0].strip()
|
||||
or http_request.headers.get("x-real-ip", "")
|
||||
or (http_request.client.host if http_request.client else "unknown")
|
||||
)
|
||||
rate_key = f"rate:check-invite:{client_ip}"
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
# Use a pipeline so that incr + expire are sent atomically.
|
||||
# This prevents the key from persisting indefinitely when expire fails
|
||||
# after a successful incr (which would permanently block the IP once
|
||||
# the count exceeds the limit).
|
||||
# NOTE: pipeline command methods (incr, expire) are NOT awaitable —
|
||||
# they queue the command and return the pipeline. Only execute() is
|
||||
# awaited, which flushes all queued commands in a single round-trip.
|
||||
pipe = redis.pipeline()
|
||||
pipe.incr(rate_key)
|
||||
pipe.expire(rate_key, _CHECK_INVITE_RATE_WINDOW)
|
||||
results = await pipe.execute()
|
||||
count = results[0]
|
||||
if count > _CHECK_INVITE_RATE_LIMIT:
|
||||
raise HTTPException(status_code=429, detail="Too many requests")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.debug("Rate limit check failed for check-invite, failing open")
|
||||
|
||||
if not settings.config.enable_invite_gate:
|
||||
return CheckInviteResponse(allowed=True)
|
||||
|
||||
if is_internal_email(request.email):
|
||||
return CheckInviteResponse(allowed=True)
|
||||
|
||||
allowed = await check_invite_eligibility(request.email)
|
||||
return CheckInviteResponse(allowed=allowed)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user",
|
||||
summary="Get or create user",
|
||||
|
||||
@@ -35,6 +35,102 @@ def setup_app_auth(mock_jwt_user, setup_test_user):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# check_invite_route tests
|
||||
|
||||
_RATE_LIMIT_PATCH = "backend.api.features.v1.get_redis_async"
|
||||
|
||||
|
||||
def _make_redis_mock(count: int = 1) -> AsyncMock:
|
||||
"""Return a mock Redis client that reports `count` for the rate-limit key.
|
||||
|
||||
The route uses a pipeline where incr/expire are synchronous (they queue
|
||||
commands and return the pipeline) and only execute() is awaited.
|
||||
"""
|
||||
mock_pipe = Mock()
|
||||
mock_pipe.incr = Mock(return_value=mock_pipe)
|
||||
mock_pipe.expire = Mock(return_value=mock_pipe)
|
||||
mock_pipe.execute = AsyncMock(return_value=[count, True])
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = Mock(return_value=mock_pipe)
|
||||
return mock_redis
|
||||
|
||||
|
||||
def test_check_invite_gate_disabled(mocker: pytest_mock.MockFixture) -> None:
|
||||
"""When enable_invite_gate is False every email is allowed."""
|
||||
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings",
|
||||
Mock(config=Mock(enable_invite_gate=False)),
|
||||
)
|
||||
|
||||
response = client.post("/auth/check-invite", json={"email": "anyone@example.com"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"allowed": True}
|
||||
|
||||
|
||||
def test_check_invite_internal_email_bypasses_gate(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""@agpt.co addresses bypass the gate even when it is enabled."""
|
||||
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings",
|
||||
Mock(config=Mock(enable_invite_gate=True)),
|
||||
)
|
||||
|
||||
response = client.post("/auth/check-invite", json={"email": "employee@agpt.co"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"allowed": True}
|
||||
|
||||
|
||||
def test_check_invite_eligible_email(mocker: pytest_mock.MockFixture) -> None:
|
||||
"""An email with INVITED status is allowed when the gate is enabled."""
|
||||
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings",
|
||||
Mock(config=Mock(enable_invite_gate=True)),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.check_invite_eligibility",
|
||||
new=AsyncMock(return_value=True),
|
||||
)
|
||||
|
||||
response = client.post("/auth/check-invite", json={"email": "invited@example.com"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"allowed": True}
|
||||
|
||||
|
||||
def test_check_invite_ineligible_email(mocker: pytest_mock.MockFixture) -> None:
|
||||
"""An email without an active invite is denied when the gate is enabled."""
|
||||
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings",
|
||||
Mock(config=Mock(enable_invite_gate=True)),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.check_invite_eligibility",
|
||||
new=AsyncMock(return_value=False),
|
||||
)
|
||||
|
||||
response = client.post("/auth/check-invite", json={"email": "stranger@example.com"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"allowed": False}
|
||||
|
||||
|
||||
def test_check_invite_rate_limit_exceeded(mocker: pytest_mock.MockFixture) -> None:
|
||||
"""Requests beyond the per-IP rate limit receive HTTP 429."""
|
||||
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock(count=11))
|
||||
|
||||
response = client.post("/auth/check-invite", json={"email": "flood@example.com"})
|
||||
|
||||
assert response.status_code == 429
|
||||
|
||||
|
||||
# Auth endpoints tests
|
||||
def test_get_or_create_user_route(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Shared configuration for all AgentMail blocks.
|
||||
"""
|
||||
|
||||
from agentmail import AsyncAgentMail
|
||||
|
||||
from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
|
||||
|
||||
agent_mail = (
|
||||
ProviderBuilder("agent_mail")
|
||||
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="agent_mail",
|
||||
title="Mock AgentMail API Key",
|
||||
api_key=SecretStr("mock-agentmail-api-key"),
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def _client(credentials: APIKeyCredentials) -> AsyncAgentMail:
|
||||
"""Create an AsyncAgentMail client from credentials."""
|
||||
return AsyncAgentMail(api_key=credentials.api_key.get_secret_value())
|
||||
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
AgentMail Attachment blocks — download file attachments from messages and threads.
|
||||
|
||||
Attachments are files associated with messages (PDFs, CSVs, images, etc.).
|
||||
To send attachments, include them in the attachments parameter when using
|
||||
AgentMailSendMessageBlock or AgentMailReplyToMessageBlock.
|
||||
|
||||
To download, first get the attachment_id from a message's attachments array,
|
||||
then use these blocks to retrieve the file content as base64.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailGetMessageAttachmentBlock(Block):
|
||||
"""
|
||||
Download a file attachment from a specific email message.
|
||||
|
||||
Retrieves the raw file content and returns it as base64-encoded data.
|
||||
First get the attachment_id from a message object's attachments array,
|
||||
then use this block to download the file.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID containing the attachment"
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="Attachment ID to download (from the message's attachments array)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
content_base64: str = SchemaField(
|
||||
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="The attachment ID that was downloaded"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a283ffc4-8087-4c3d-9135-8f26b86742ec",
|
||||
description="Download a file attachment from an email message. Returns base64-encoded file content.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"attachment_id": "test-attach",
|
||||
},
|
||||
test_output=[
|
||||
("content_base64", "dGVzdA=="),
|
||||
("attachment_id", "test-attach"),
|
||||
],
|
||||
test_mock={
|
||||
"get_attachment": lambda *a, **kw: b"test",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_attachment(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
message_id: str,
|
||||
attachment_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.get_attachment(
|
||||
inbox_id=inbox_id,
|
||||
message_id=message_id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_attachment(
|
||||
credentials=credentials,
|
||||
inbox_id=input_data.inbox_id,
|
||||
message_id=input_data.message_id,
|
||||
attachment_id=input_data.attachment_id,
|
||||
)
|
||||
if isinstance(data, bytes):
|
||||
encoded = base64.b64encode(data).decode()
|
||||
elif isinstance(data, str):
|
||||
encoded = base64.b64encode(data.encode("utf-8")).decode()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected attachment data type: {type(data).__name__}"
|
||||
)
|
||||
|
||||
yield "content_base64", encoded
|
||||
yield "attachment_id", input_data.attachment_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetThreadAttachmentBlock(Block):
|
||||
"""
|
||||
Download a file attachment from a conversation thread.
|
||||
|
||||
Same as GetMessageAttachment but looks up by thread ID instead of
|
||||
message ID. Useful when you know the thread but not the specific
|
||||
message containing the attachment.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID containing the attachment")
|
||||
attachment_id: str = SchemaField(
|
||||
description="Attachment ID to download (from a message's attachments array within the thread)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
content_base64: str = SchemaField(
|
||||
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="The attachment ID that was downloaded"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="06b6a4c4-9d71-4992-9e9c-cf3b352763b5",
|
||||
description="Download a file attachment from a conversation thread. Returns base64-encoded file content.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
"attachment_id": "test-attach",
|
||||
},
|
||||
test_output=[
|
||||
("content_base64", "dGVzdA=="),
|
||||
("attachment_id", "test-attach"),
|
||||
],
|
||||
test_mock={
|
||||
"get_attachment": lambda *a, **kw: b"test",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_attachment(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
thread_id: str,
|
||||
attachment_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.get_attachment(
|
||||
inbox_id=inbox_id,
|
||||
thread_id=thread_id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_attachment(
|
||||
credentials=credentials,
|
||||
inbox_id=input_data.inbox_id,
|
||||
thread_id=input_data.thread_id,
|
||||
attachment_id=input_data.attachment_id,
|
||||
)
|
||||
if isinstance(data, bytes):
|
||||
encoded = base64.b64encode(data).decode()
|
||||
elif isinstance(data, str):
|
||||
encoded = base64.b64encode(data.encode("utf-8")).decode()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected attachment data type: {type(data).__name__}"
|
||||
)
|
||||
|
||||
yield "content_base64", encoded
|
||||
yield "attachment_id", input_data.attachment_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
678
autogpt_platform/backend/backend/blocks/agent_mail/drafts.py
Normal file
678
autogpt_platform/backend/backend/blocks/agent_mail/drafts.py
Normal file
@@ -0,0 +1,678 @@
|
||||
"""
|
||||
AgentMail Draft blocks — create, get, list, update, send, and delete drafts.
|
||||
|
||||
A Draft is an unsent message that can be reviewed, edited, and sent later.
|
||||
Drafts enable human-in-the-loop review, scheduled sending (via send_at),
|
||||
and complex multi-step email composition workflows.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreateDraftBlock(Block):
|
||||
"""
|
||||
Create a draft email in an AgentMail inbox for review or scheduled sending.
|
||||
|
||||
Drafts let agents prepare emails without sending immediately. Use send_at
|
||||
to schedule automatic sending at a future time (ISO 8601 format).
|
||||
Scheduled drafts are auto-labeled 'scheduled' and can be cancelled by
|
||||
deleting the draft.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to create the draft in"
|
||||
)
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses (e.g. ['user@example.com'])"
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject line", default="")
|
||||
text: str = SchemaField(description="Plain text body of the draft", default="")
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the draft", default="", advanced=True
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
in_reply_to: str = SchemaField(
|
||||
description="Message ID this draft replies to, for threading follow-up drafts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
send_at: str = SchemaField(
|
||||
description="Schedule automatic sending at this ISO 8601 datetime (e.g. '2025-01-15T09:00:00Z'). Leave empty for manual send.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(
|
||||
description="Unique identifier of the created draft"
|
||||
)
|
||||
send_status: str = SchemaField(
|
||||
description="'scheduled' if send_at was set, empty otherwise. Values: scheduled, sending, failed.",
|
||||
default="",
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete draft object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="25ac9086-69fd-48b8-b910-9dbe04b8f3bd",
|
||||
description="Create a draft email for review or scheduled sending. Use send_at for automatic future delivery.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"to": ["user@example.com"],
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "mock-draft-id"),
|
||||
("send_status", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "mock-draft-id",
|
||||
"send_status": "",
|
||||
"model_dump": lambda self: {"draft_id": "mock-draft-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_draft(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.create(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"to": input_data.to}
|
||||
if input_data.subject:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.in_reply_to:
|
||||
params["in_reply_to"] = input_data.in_reply_to
|
||||
if input_data.send_at:
|
||||
params["send_at"] = input_data.send_at
|
||||
|
||||
draft = await self.create_draft(credentials, input_data.inbox_id, **params)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetDraftBlock(Block):
|
||||
"""
|
||||
Retrieve a specific draft from an AgentMail inbox.
|
||||
|
||||
Returns the draft contents including recipients, subject, body, and
|
||||
scheduled send status. Use this to review a draft before approving it.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(description="Unique identifier of the draft")
|
||||
subject: str = SchemaField(description="Draft subject line", default="")
|
||||
send_status: str = SchemaField(
|
||||
description="Scheduled send status: 'scheduled', 'sending', 'failed', or empty",
|
||||
default="",
|
||||
)
|
||||
send_at: str = SchemaField(
|
||||
description="Scheduled send time (ISO 8601) if set", default=""
|
||||
)
|
||||
result: dict = SchemaField(description="Complete draft object with all fields")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8e57780d-dc25-43d4-a0f4-1f02877b09fb",
|
||||
description="Retrieve a draft email to review its contents, recipients, and scheduled send status.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "test-draft"),
|
||||
("subject", ""),
|
||||
("send_status", ""),
|
||||
("send_at", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "test-draft",
|
||||
"subject": "",
|
||||
"send_status": "",
|
||||
"send_at": "",
|
||||
"model_dump": lambda self: {"draft_id": "test-draft"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.get(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
draft = await self.get_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "subject", draft.subject or ""
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "send_at", draft.send_at or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts in an AgentMail inbox with optional label filtering.
|
||||
|
||||
Use labels=['scheduled'] to find all drafts queued for future sending.
|
||||
Useful for building approval dashboards or monitoring pending outreach.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list drafts from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Filter drafts by labels (e.g. ['scheduled'] for pending sends)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects with subject, recipients, send_status, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e84883b7-7c39-4c5c-88e8-0a72b078ea63",
|
||||
description="List drafts in an AgentMail inbox. Filter by labels=['scheduled'] to find pending sends.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_drafts(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.list(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_drafts(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateDraftBlock(Block):
|
||||
"""
|
||||
Update an existing draft's content, recipients, or scheduled send time.
|
||||
|
||||
Use this to reschedule a draft (change send_at), modify recipients,
|
||||
or edit the subject/body before sending. To cancel a scheduled send,
|
||||
delete the draft instead.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to update")
|
||||
to: Optional[list[str]] = SchemaField(
|
||||
description="Updated recipient email addresses (replaces existing list). Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
subject: Optional[str] = SchemaField(
|
||||
description="Updated subject line. Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
text: Optional[str] = SchemaField(
|
||||
description="Updated plain text body. Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
html: Optional[str] = SchemaField(
|
||||
description="Updated HTML body. Omit to keep current value.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
send_at: Optional[str] = SchemaField(
|
||||
description="Reschedule: new ISO 8601 send time (e.g. '2025-01-20T14:00:00Z'). Omit to keep current value.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(description="The updated draft ID")
|
||||
send_status: str = SchemaField(description="Updated send status", default="")
|
||||
result: dict = SchemaField(description="Complete updated draft object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="351f6e51-695a-421a-9032-46a587b10336",
|
||||
description="Update a draft's content, recipients, or scheduled send time. Use to reschedule or edit before sending.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "test-draft"),
|
||||
("send_status", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "test-draft",
|
||||
"send_status": "",
|
||||
"model_dump": lambda self: {"draft_id": "test-draft"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_draft(
|
||||
credentials: APIKeyCredentials, inbox_id: str, draft_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.update(
|
||||
inbox_id=inbox_id, draft_id=draft_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.to is not None:
|
||||
params["to"] = input_data.to
|
||||
if input_data.subject is not None:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text is not None:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html is not None:
|
||||
params["html"] = input_data.html
|
||||
if input_data.send_at is not None:
|
||||
params["send_at"] = input_data.send_at
|
||||
|
||||
draft = await self.update_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id, **params
|
||||
)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailSendDraftBlock(Block):
|
||||
"""
|
||||
Send a draft immediately, converting it into a delivered message.
|
||||
|
||||
The draft is deleted after successful sending and becomes a regular
|
||||
message with a message_id. Use this for human-in-the-loop approval
|
||||
workflows: agent creates draft, human reviews, then this block sends it.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to send now")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID of the now-sent email (draft is deleted)"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID the sent message belongs to"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete sent message object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="37c39e83-475d-4b3d-843a-d923d001b85a",
|
||||
description="Send a draft immediately, converting it into a delivered message. The draft is deleted after sending.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-msg-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"send_draft": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-msg-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.send(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
msg = await self.send_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteDraftBlock(Block):
|
||||
"""
|
||||
Delete a draft from an AgentMail inbox. Also cancels any scheduled send.
|
||||
|
||||
If the draft was scheduled with send_at, deleting it cancels the
|
||||
scheduled delivery. This is the way to cancel a scheduled email.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(
|
||||
description="Draft ID to delete (also cancels scheduled sends)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the draft was successfully deleted/cancelled"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9023eb99-3e2f-4def-808b-d9c584b3d9e7",
|
||||
description="Delete a draft or cancel a scheduled email. Removes the draft permanently.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_draft": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_draft(
|
||||
credentials: APIKeyCredentials, inbox_id: str, draft_id: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.drafts.delete(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListOrgDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts across every inbox in your organization.
|
||||
|
||||
Returns drafts from all inboxes in one query. Perfect for building
|
||||
a central approval dashboard where a human supervisor can review
|
||||
and approve any draft created by any agent.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects from all inboxes in the organization"
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ed7558ae-3a07-45f5-af55-a25fe88c9971",
|
||||
description="List all drafts across every inbox in your organization. Use for central approval dashboards.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_org_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_org_drafts(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.drafts.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_org_drafts(credentials, **params)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
414
autogpt_platform/backend/backend/blocks/agent_mail/inbox.py
Normal file
414
autogpt_platform/backend/backend/blocks/agent_mail/inbox.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
AgentMail Inbox blocks — create, get, list, update, and delete inboxes.
|
||||
|
||||
An Inbox is a fully programmable email account for AI agents. Each inbox gets
|
||||
a unique email address and can send, receive, and manage emails via the
|
||||
AgentMail API. You can create thousands of inboxes on demand.
|
||||
"""
|
||||
|
||||
from agentmail.inboxes.types import CreateInboxRequest
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreateInboxBlock(Block):
|
||||
"""
|
||||
Create a new email inbox for an AI agent via AgentMail.
|
||||
|
||||
Each inbox gets a unique email address (e.g. username@agentmail.to).
|
||||
If username and domain are not provided, AgentMail auto-generates them.
|
||||
Use custom domains by specifying the domain field.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
username: str = SchemaField(
|
||||
description="Local part of the email address (e.g. 'support' for support@domain.com). Leave empty to auto-generate.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field of sent emails (e.g. 'Support Agent')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(
|
||||
description="Unique identifier for the created inbox (also the email address)"
|
||||
)
|
||||
email_address: str = SchemaField(
|
||||
description="Full email address of the inbox (e.g. support@agentmail.to)"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7a8ac219-c6ec-4eec-a828-81af283ce04c",
|
||||
description="Create a new email inbox for an AI agent via AgentMail. Each inbox gets a unique address and can send/receive emails.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("inbox_id", "mock-inbox-id"),
|
||||
("email_address", "mock-inbox-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "mock-inbox-id",
|
||||
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_inbox(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.create(request=CreateInboxRequest(**params))
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.username:
|
||||
params["username"] = input_data.username
|
||||
if input_data.domain:
|
||||
params["domain"] = input_data.domain
|
||||
if input_data.display_name:
|
||||
params["display_name"] = input_data.display_name
|
||||
|
||||
inbox = await self.create_inbox(credentials, **params)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetInboxBlock(Block):
|
||||
"""
|
||||
Retrieve details of an existing AgentMail inbox by its ID or email address.
|
||||
|
||||
Returns the inbox metadata including email address, display name, and
|
||||
configuration. Use this to check if an inbox exists or get its properties.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to look up (e.g. 'support@agentmail.to')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(description="Unique identifier of the inbox")
|
||||
email_address: str = SchemaField(description="Full email address of the inbox")
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field", default=""
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b858f62b-6c12-4736-aaf2-dbc5a9281320",
|
||||
description="Retrieve details of an existing AgentMail inbox including its email address, display name, and configuration.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("inbox_id", "test-inbox"),
|
||||
("email_address", "test-inbox"),
|
||||
("display_name", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "test-inbox",
|
||||
"display_name": "",
|
||||
"model_dump": lambda self: {"inbox_id": "test-inbox"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_inbox(credentials: APIKeyCredentials, inbox_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.get(inbox_id=inbox_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
inbox = await self.get_inbox(credentials, input_data.inbox_id)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "display_name", inbox.display_name or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListInboxesBlock(Block):
|
||||
"""
|
||||
List all email inboxes in your AgentMail organization.
|
||||
|
||||
Returns a paginated list of all inboxes with their metadata.
|
||||
Use page_token for pagination when you have many inboxes.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of inboxes to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page of results",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inboxes: list[dict] = SchemaField(
|
||||
description="List of inbox objects, each containing inbox_id, email_address, display_name, etc."
|
||||
)
|
||||
count: int = SchemaField(
|
||||
description="Total number of inboxes in your organization"
|
||||
)
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token to pass as page_token to get the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cfd84a06-2121-4cef-8d14-8badf52d22f0",
|
||||
description="List all email inboxes in your AgentMail organization with pagination support.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("inboxes", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_inboxes": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"inboxes": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_inboxes(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_inboxes(credentials, **params)
|
||||
inboxes = [i.model_dump() for i in response.inboxes]
|
||||
|
||||
yield "inboxes", inboxes
|
||||
yield "count", (c if (c := response.count) is not None else len(inboxes))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateInboxBlock(Block):
|
||||
"""
|
||||
Update the display name of an existing AgentMail inbox.
|
||||
|
||||
Changes the friendly name shown in the 'From' field when emails are sent
|
||||
from this inbox. The email address itself cannot be changed.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to update (e.g. 'support@agentmail.to')"
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="New display name for the inbox (e.g. 'Customer Support Bot')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(description="The updated inbox ID")
|
||||
result: dict = SchemaField(
|
||||
description="Complete updated inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59b49f59-a6d1-4203-94c0-3908adac50b6",
|
||||
description="Update the display name of an AgentMail inbox. Changes the 'From' name shown when emails are sent.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"display_name": "Updated",
|
||||
},
|
||||
test_output=[
|
||||
("inbox_id", "test-inbox"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "test-inbox",
|
||||
"model_dump": lambda self: {"inbox_id": "test-inbox"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_inbox(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.update(inbox_id=inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
inbox = await self.update_inbox(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
display_name=input_data.display_name,
|
||||
)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteInboxBlock(Block):
|
||||
"""
|
||||
Permanently delete an AgentMail inbox and all its data.
|
||||
|
||||
This removes the inbox, all its messages, threads, and drafts.
|
||||
This action cannot be undone. The email address will no longer
|
||||
receive or send emails.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to permanently delete"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the inbox was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ade970ae-8428-4a7b-9278-b52054dbf535",
|
||||
description="Permanently delete an AgentMail inbox and all its messages, threads, and drafts. This action cannot be undone.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_inbox": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_inbox(credentials: APIKeyCredentials, inbox_id: str):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.delete(inbox_id=inbox_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_inbox(credentials, input_data.inbox_id)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
384
autogpt_platform/backend/backend/blocks/agent_mail/lists.py
Normal file
384
autogpt_platform/backend/backend/blocks/agent_mail/lists.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
AgentMail List blocks — manage allow/block lists for email filtering.
|
||||
|
||||
Lists let you control which email addresses and domains your agents can
|
||||
send to or receive from. There are four list types based on two dimensions:
|
||||
direction (send/receive) and type (allow/block).
|
||||
|
||||
- receive + allow: Only accept emails from these addresses/domains
|
||||
- receive + block: Reject emails from these addresses/domains
|
||||
- send + allow: Only send emails to these addresses/domains
|
||||
- send + block: Prevent sending emails to these addresses/domains
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class ListDirection(str, Enum):
|
||||
SEND = "send"
|
||||
RECEIVE = "receive"
|
||||
|
||||
|
||||
class ListType(str, Enum):
|
||||
ALLOW = "allow"
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
class AgentMailListEntriesBlock(Block):
|
||||
"""
|
||||
List all entries in an AgentMail allow/block list.
|
||||
|
||||
Retrieves email addresses and domains that are currently allowed
|
||||
or blocked for sending or receiving. Use direction and list_type
|
||||
to select which of the four lists to query.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' to filter outgoing emails, 'receive' to filter incoming emails"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist (only permit these), 'block' for blacklist (reject these)"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of entries to return per page",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entries: list[dict] = SchemaField(
|
||||
description="List of entries, each with an email address or domain"
|
||||
)
|
||||
count: int = SchemaField(description="Number of entries returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="01489100-35da-45aa-8a01-9540ba0e9a21",
|
||||
description="List all entries in an AgentMail allow/block list. Choose send/receive direction and allow/block type.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
},
|
||||
test_output=[
|
||||
("entries", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_entries": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"entries": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_entries(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.list(direction, list_type, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_entries(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
**params,
|
||||
)
|
||||
entries = [e.model_dump() for e in response.entries]
|
||||
|
||||
yield "entries", entries
|
||||
yield "count", (c if (c := response.count) is not None else len(entries))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailCreateListEntryBlock(Block):
|
||||
"""
|
||||
Add an email address or domain to an AgentMail allow/block list.
|
||||
|
||||
Entries can be full email addresses (e.g. 'partner@example.com') or
|
||||
entire domains (e.g. 'example.com'). For block lists, you can optionally
|
||||
provide a reason (e.g. 'spam', 'competitor').
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing email rules, 'receive' for incoming email rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' to whitelist, 'block' to blacklist"
|
||||
)
|
||||
entry: str = SchemaField(
|
||||
description="Email address (user@example.com) or domain (example.com) to add"
|
||||
)
|
||||
reason: str = SchemaField(
|
||||
description="Reason for blocking (only used with block lists, e.g. 'spam', 'competitor')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entry: str = SchemaField(
|
||||
description="The email address or domain that was added"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete entry object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b6650a0a-b113-40cf-8243-ff20f684f9b8",
|
||||
description="Add an email address or domain to an allow/block list. Block spam senders or whitelist trusted domains.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[
|
||||
("entry", "spam@example.com"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_entry": lambda *a, **kw: type(
|
||||
"Entry",
|
||||
(),
|
||||
{
|
||||
"model_dump": lambda self: {"entry": "spam@example.com"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.create(direction, list_type, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"entry": input_data.entry}
|
||||
if input_data.reason and input_data.list_type == ListType.BLOCK:
|
||||
params["reason"] = input_data.reason
|
||||
|
||||
result = await self.create_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
**params,
|
||||
)
|
||||
result_dict = result.model_dump()
|
||||
|
||||
yield "entry", input_data.entry
|
||||
yield "result", result_dict
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetListEntryBlock(Block):
|
||||
"""
|
||||
Check if an email address or domain exists in an AgentMail allow/block list.
|
||||
|
||||
Returns the entry details if found. Use this to verify whether a specific
|
||||
address or domain is currently allowed or blocked.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing rules, 'receive' for incoming rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist, 'block' for blacklist"
|
||||
)
|
||||
entry: str = SchemaField(description="Email address or domain to look up")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entry: str = SchemaField(
|
||||
description="The email address or domain that was found"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete entry object with metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fb117058-ab27-40d1-9231-eb1dd526fc7a",
|
||||
description="Check if an email address or domain is in an allow/block list. Verify filtering rules.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[
|
||||
("entry", "spam@example.com"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_entry": lambda *a, **kw: type(
|
||||
"Entry",
|
||||
(),
|
||||
{
|
||||
"model_dump": lambda self: {"entry": "spam@example.com"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.get(direction, list_type, entry=entry)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.get_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
input_data.entry,
|
||||
)
|
||||
result_dict = result.model_dump()
|
||||
|
||||
yield "entry", input_data.entry
|
||||
yield "result", result_dict
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteListEntryBlock(Block):
|
||||
"""
|
||||
Remove an email address or domain from an AgentMail allow/block list.
|
||||
|
||||
After removal, the address/domain will no longer be filtered by this list.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing rules, 'receive' for incoming rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist, 'block' for blacklist"
|
||||
)
|
||||
entry: str = SchemaField(
|
||||
description="Email address or domain to remove from the list"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the entry was successfully removed"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2b8d57f1-1c9e-470f-a70b-5991c80fad5f",
|
||||
description="Remove an email address or domain from an allow/block list to stop filtering it.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_entry": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.lists.delete(direction, list_type, entry=entry)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
input_data.entry,
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
695
autogpt_platform/backend/backend/blocks/agent_mail/messages.py
Normal file
695
autogpt_platform/backend/backend/blocks/agent_mail/messages.py
Normal file
@@ -0,0 +1,695 @@
|
||||
"""
|
||||
AgentMail Message blocks — send, list, get, reply, forward, and update messages.
|
||||
|
||||
A Message is an individual email within a Thread. Agents can send new messages
|
||||
(which create threads), reply to existing messages, forward them, and manage
|
||||
labels for state tracking (e.g. read/unread, campaign tags).
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailSendMessageBlock(Block):
|
||||
"""
|
||||
Send a new email from an AgentMail inbox, automatically creating a new thread.
|
||||
|
||||
Supports plain text and HTML bodies, CC/BCC recipients, and labels for
|
||||
organizing messages (e.g. campaign tracking, state management).
|
||||
Max 50 combined recipients across to, cc, and bcc.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to send from (e.g. 'agent@agentmail.to')"
|
||||
)
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses (e.g. ['user@example.com'])"
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject line")
|
||||
text: str = SchemaField(
|
||||
description="Plain text body of the email. Always provide this as a fallback for email clients that don't render HTML."
|
||||
)
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the email. Embed CSS in a <style> tag for best compatibility across email clients.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses for human-in-the-loop oversight",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses (hidden from other recipients)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Labels to tag the message for filtering and state management (e.g. ['outreach', 'q4-campaign'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the sent message"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID grouping this message and any future replies"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete sent message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b67469b2-7748-4d81-a223-4ebd332cca89",
|
||||
description="Send a new email from an AgentMail inbox. Creates a new conversation thread. Supports HTML, CC/BCC, and labels.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"to": ["user@example.com"],
|
||||
"subject": "Test",
|
||||
"text": "Hello",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-msg-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"send_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_message(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.send(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
|
||||
if total > 50:
|
||||
raise ValueError(
|
||||
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"to": input_data.to,
|
||||
"subject": input_data.subject,
|
||||
"text": input_data.text,
|
||||
}
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
msg = await self.send_message(credentials, input_data.inbox_id, **params)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListMessagesBlock(Block):
|
||||
"""
|
||||
List all messages in an AgentMail inbox with optional label filtering.
|
||||
|
||||
Returns a paginated list of messages. Use labels to filter (e.g.
|
||||
labels=['unread'] to only get unprocessed messages). Useful for
|
||||
polling workflows or building inbox views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list messages from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of messages to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return messages with ALL of these labels (e.g. ['unread'] or ['q4-campaign', 'follow-up'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
messages: list[dict] = SchemaField(
|
||||
description="List of message objects with subject, sender, text, html, labels, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of messages returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="721234df-c7a2-4927-b205-744badbd5844",
|
||||
description="List messages in an AgentMail inbox. Filter by labels to find unread, campaign-tagged, or categorized messages.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("messages", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_messages": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"messages": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_messages(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.list(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_messages(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
messages = [m.model_dump() for m in response.messages]
|
||||
|
||||
yield "messages", messages
|
||||
yield "count", (c if (c := response.count) is not None else len(messages))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetMessageBlock(Block):
|
||||
"""
|
||||
Retrieve a specific email message by ID from an AgentMail inbox.
|
||||
|
||||
Returns the full message including subject, body (text and HTML),
|
||||
sender, recipients, and attachments. Use extracted_text to get
|
||||
only the new reply content without quoted history.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to retrieve (e.g. '<abc123@agentmail.to>')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(description="Unique identifier of the message")
|
||||
thread_id: str = SchemaField(description="Thread this message belongs to")
|
||||
subject: str = SchemaField(description="Email subject line")
|
||||
text: str = SchemaField(
|
||||
description="Full plain text body (may include quoted reply history)"
|
||||
)
|
||||
extracted_text: str = SchemaField(
|
||||
description="Just the new reply content with quoted history stripped. Best for AI processing.",
|
||||
default="",
|
||||
)
|
||||
html: str = SchemaField(description="HTML body of the email", default="")
|
||||
result: dict = SchemaField(
|
||||
description="Complete message object with all fields including sender, recipients, attachments, labels"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2788bdfa-1527-4603-a5e4-a455c05c032f",
|
||||
description="Retrieve a specific email message by ID. Includes extracted_text for clean reply content without quoted history.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "test-msg"),
|
||||
("thread_id", "t1"),
|
||||
("subject", "Hi"),
|
||||
("text", "Hello"),
|
||||
("extracted_text", "Hello"),
|
||||
("html", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "test-msg",
|
||||
"thread_id": "t1",
|
||||
"subject": "Hi",
|
||||
"text": "Hello",
|
||||
"extracted_text": "Hello",
|
||||
"html": "",
|
||||
"model_dump": lambda self: {"message_id": "test-msg"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_message(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
message_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.get(
|
||||
inbox_id=inbox_id, message_id=message_id
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
msg = await self.get_message(
|
||||
credentials, input_data.inbox_id, input_data.message_id
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "subject", msg.subject or ""
|
||||
yield "text", msg.text or ""
|
||||
yield "extracted_text", msg.extracted_text or ""
|
||||
yield "html", msg.html or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailReplyToMessageBlock(Block):
|
||||
"""
|
||||
Reply to an existing email message, keeping the reply in the same thread.
|
||||
|
||||
The reply is automatically added to the same conversation thread as the
|
||||
original message. Use this for multi-turn agent conversations.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to send the reply from"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to reply to (e.g. '<abc123@agentmail.to>')"
|
||||
)
|
||||
text: str = SchemaField(description="Plain text body of the reply")
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the reply",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the reply message"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID the reply was added to")
|
||||
result: dict = SchemaField(
|
||||
description="Complete reply message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b9fe53fa-5026-4547-9570-b54ccb487229",
|
||||
description="Reply to an existing email in the same conversation thread. Use for multi-turn agent conversations.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"text": "Reply",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-reply-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"reply_to_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-reply-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-reply-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def reply_to_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.reply(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"text": input_data.text}
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
|
||||
reply = await self.reply_to_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = reply.model_dump()
|
||||
|
||||
yield "message_id", reply.message_id
|
||||
yield "thread_id", reply.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailForwardMessageBlock(Block):
|
||||
"""
|
||||
Forward an existing email message to one or more recipients.
|
||||
|
||||
Sends the original message content to different email addresses.
|
||||
Optionally prepend additional text or override the subject line.
|
||||
Max 50 combined recipients across to, cc, and bcc.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to forward from"
|
||||
)
|
||||
message_id: str = SchemaField(description="Message ID to forward")
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses to forward the message to (e.g. ['user@example.com'])"
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses (hidden from other recipients)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
subject: str = SchemaField(
|
||||
description="Override the subject line (defaults to 'Fwd: <original subject>')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text: str = SchemaField(
|
||||
description="Additional plain text to prepend before the forwarded content",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
html: str = SchemaField(
|
||||
description="Additional HTML to prepend before the forwarded content",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the forwarded message"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID of the forward")
|
||||
result: dict = SchemaField(
|
||||
description="Complete forwarded message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b70c7e33-5d66-4f8e-897f-ac73a7bfce82",
|
||||
description="Forward an email message to one or more recipients. Supports CC/BCC and optional extra text or subject override.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"to": ["user@example.com"],
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-fwd-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"forward_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-fwd-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-fwd-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def forward_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.forward(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
|
||||
if total > 50:
|
||||
raise ValueError(
|
||||
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
|
||||
)
|
||||
|
||||
params: dict = {"to": input_data.to}
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.subject:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
|
||||
fwd = await self.forward_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = fwd.model_dump()
|
||||
|
||||
yield "message_id", fwd.message_id
|
||||
yield "thread_id", fwd.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateMessageBlock(Block):
|
||||
"""
|
||||
Add or remove labels on an email message for state management.
|
||||
|
||||
Labels are string tags used to track message state (read/unread),
|
||||
categorize messages (billing, support), or tag campaigns (q4-outreach).
|
||||
Common pattern: add 'read' and remove 'unread' after processing a message.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(description="Message ID to update labels on")
|
||||
add_labels: list[str] = SchemaField(
|
||||
description="Labels to add (e.g. ['read', 'processed', 'high-priority'])",
|
||||
default_factory=list,
|
||||
)
|
||||
remove_labels: list[str] = SchemaField(
|
||||
description="Labels to remove (e.g. ['unread', 'pending'])",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(description="The updated message ID")
|
||||
result: dict = SchemaField(
|
||||
description="Complete updated message object with current labels"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="694ff816-4c89-4a5e-a552-8c31be187735",
|
||||
description="Add or remove labels on an email message. Use for read/unread tracking, campaign tagging, or state management.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"add_labels": ["read"],
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "test-msg"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "test-msg",
|
||||
"model_dump": lambda self: {"message_id": "test-msg"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.update(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
if not input_data.add_labels and not input_data.remove_labels:
|
||||
raise ValueError(
|
||||
"Must specify at least one label operation: add_labels or remove_labels"
|
||||
)
|
||||
|
||||
params: dict = {}
|
||||
if input_data.add_labels:
|
||||
params["add_labels"] = input_data.add_labels
|
||||
if input_data.remove_labels:
|
||||
params["remove_labels"] = input_data.remove_labels
|
||||
|
||||
msg = await self.update_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
651
autogpt_platform/backend/backend/blocks/agent_mail/pods.py
Normal file
651
autogpt_platform/backend/backend/blocks/agent_mail/pods.py
Normal file
@@ -0,0 +1,651 @@
|
||||
"""
|
||||
AgentMail Pod blocks — create, get, list, delete pods and list pod-scoped resources.
|
||||
|
||||
Pods provide multi-tenant isolation between your customers. Each pod acts as
|
||||
an isolated workspace containing its own inboxes, domains, threads, and drafts.
|
||||
Use pods when building SaaS platforms, agency tools, or AI agent fleets that
|
||||
serve multiple customers.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreatePodBlock(Block):
|
||||
"""
|
||||
Create a new pod for multi-tenant customer isolation.
|
||||
|
||||
Each pod acts as an isolated workspace for one customer or tenant.
|
||||
Use client_id to map pods to your internal tenant IDs for idempotent
|
||||
creation (safe to retry without creating duplicates).
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
client_id: str = SchemaField(
|
||||
description="Your internal tenant/customer ID for idempotent mapping. Lets you access the pod by your own ID instead of AgentMail's pod_id.",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pod_id: str = SchemaField(description="Unique identifier of the created pod")
|
||||
result: dict = SchemaField(description="Complete pod object with all metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a2db9784-2d17-4f8f-9d6b-0214e6f22101",
|
||||
description="Create a new pod for multi-tenant customer isolation. Use client_id to map to your internal tenant IDs.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("pod_id", "mock-pod-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_pod": lambda *a, **kw: type(
|
||||
"Pod",
|
||||
(),
|
||||
{
|
||||
"pod_id": "mock-pod-id",
|
||||
"model_dump": lambda self: {"pod_id": "mock-pod-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_pod(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.create(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.client_id:
|
||||
params["client_id"] = input_data.client_id
|
||||
|
||||
pod = await self.create_pod(credentials, **params)
|
||||
result = pod.model_dump()
|
||||
|
||||
yield "pod_id", pod.pod_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetPodBlock(Block):
|
||||
"""
|
||||
Retrieve details of an existing pod by its ID.
|
||||
|
||||
Returns the pod metadata including its client_id mapping and
|
||||
creation timestamp.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pod_id: str = SchemaField(description="Unique identifier of the pod")
|
||||
result: dict = SchemaField(description="Complete pod object with all metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="553361bc-bb1b-4322-9ad4-0c226200217e",
|
||||
description="Retrieve details of an existing pod including its client_id mapping and metadata.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("pod_id", "test-pod"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_pod": lambda *a, **kw: type(
|
||||
"Pod",
|
||||
(),
|
||||
{
|
||||
"pod_id": "test-pod",
|
||||
"model_dump": lambda self: {"pod_id": "test-pod"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_pod(credentials: APIKeyCredentials, pod_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.pods.get(pod_id=pod_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
pod = await self.get_pod(credentials, pod_id=input_data.pod_id)
|
||||
result = pod.model_dump()
|
||||
|
||||
yield "pod_id", pod.pod_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodsBlock(Block):
|
||||
"""
|
||||
List all pods in your AgentMail organization.
|
||||
|
||||
Returns a paginated list of all tenant pods with their metadata.
|
||||
Use this to see all customer workspaces at a glance.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of pods to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pods: list[dict] = SchemaField(
|
||||
description="List of pod objects with pod_id, client_id, creation time, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of pods returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9d3725ee-2968-431a-a816-857ab41e1420",
|
||||
description="List all tenant pods in your organization. See all customer workspaces at a glance.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("pods", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pods": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"pods": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pods(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pods(credentials, **params)
|
||||
pods = [p.model_dump() for p in response.pods]
|
||||
|
||||
yield "pods", pods
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeletePodBlock(Block):
|
||||
"""
|
||||
Permanently delete a pod. All inboxes and domains must be removed first.
|
||||
|
||||
You cannot delete a pod that still contains inboxes or domains.
|
||||
Delete all child resources first, then delete the pod.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(
|
||||
description="Pod ID to permanently delete (must have no inboxes or domains)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the pod was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f371f8cd-682d-4f5f-905c-529c74a8fb35",
|
||||
description="Permanently delete a pod. All inboxes and domains must be removed first.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_pod": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_pod(credentials: APIKeyCredentials, pod_id: str):
|
||||
client = _client(credentials)
|
||||
await client.pods.delete(pod_id=pod_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_pod(credentials, pod_id=input_data.pod_id)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodInboxesBlock(Block):
|
||||
"""
|
||||
List all inboxes within a specific pod (customer workspace).
|
||||
|
||||
Returns only the inboxes belonging to this pod, providing
|
||||
tenant-scoped visibility.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list inboxes from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of inboxes to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inboxes: list[dict] = SchemaField(
|
||||
description="List of inbox objects within this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of inboxes returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a8c17ce0-b7c1-4bc3-ae39-680e1952e5d0",
|
||||
description="List all inboxes within a pod. View email accounts scoped to a specific customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("inboxes", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_inboxes": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"inboxes": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_inboxes(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.inboxes.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pod_inboxes(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
inboxes = [i.model_dump() for i in response.inboxes]
|
||||
|
||||
yield "inboxes", inboxes
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodThreadsBlock(Block):
|
||||
"""
|
||||
List all conversation threads across all inboxes within a pod.
|
||||
|
||||
Returns threads from every inbox in the pod. Use for building
|
||||
per-customer dashboards showing all email activity, or for
|
||||
supervisor agents monitoring a customer's conversations.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list threads from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects from all inboxes in this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="80214f08-8b85-4533-a6b8-f8123bfcb410",
|
||||
description="List all conversation threads across all inboxes within a pod. View all email activity for a customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_threads(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.threads.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_pod_threads(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts across all inboxes within a pod.
|
||||
|
||||
Returns pending drafts from every inbox in the pod. Use for
|
||||
per-customer approval dashboards or monitoring scheduled sends.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list drafts from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects from all inboxes in this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="12fd7a3e-51ad-4b20-97c1-0391f207f517",
|
||||
description="List all drafts across all inboxes within a pod. View pending emails for a customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_drafts(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.drafts.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pod_drafts(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailCreatePodInboxBlock(Block):
|
||||
"""
|
||||
Create a new email inbox within a specific pod (customer workspace).
|
||||
|
||||
The inbox is automatically scoped to the pod and inherits its
|
||||
isolation guarantees. If username/domain are not provided,
|
||||
AgentMail auto-generates a unique address.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to create the inbox in")
|
||||
username: str = SchemaField(
|
||||
description="Local part of the email address (e.g. 'support'). Leave empty to auto-generate.",
|
||||
default="",
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
|
||||
default="",
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field (e.g. 'Customer Support')",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(
|
||||
description="Unique identifier of the created inbox"
|
||||
)
|
||||
email_address: str = SchemaField(description="Full email address of the inbox")
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c6862373-1ac6-402e-89e6-7db1fea882af",
|
||||
description="Create a new email inbox within a pod. The inbox is scoped to the customer workspace.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("inbox_id", "mock-inbox-id"),
|
||||
("email_address", "mock-inbox-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_pod_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "mock-inbox-id",
|
||||
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_pod_inbox(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.inboxes.create(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.username:
|
||||
params["username"] = input_data.username
|
||||
if input_data.domain:
|
||||
params["domain"] = input_data.domain
|
||||
if input_data.display_name:
|
||||
params["display_name"] = input_data.display_name
|
||||
|
||||
inbox = await self.create_pod_inbox(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
438
autogpt_platform/backend/backend/blocks/agent_mail/threads.py
Normal file
438
autogpt_platform/backend/backend/blocks/agent_mail/threads.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
AgentMail Thread blocks — list, get, and delete conversation threads.
|
||||
|
||||
A Thread groups related messages into a single conversation. Threads are
|
||||
created automatically when a new message is sent and grow as replies are added.
|
||||
Threads can be queried per-inbox or across the entire organization.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailListInboxThreadsBlock(Block):
|
||||
"""
|
||||
List all conversation threads within a specific AgentMail inbox.
|
||||
|
||||
Returns a paginated list of threads with optional label filtering.
|
||||
Use labels to find threads by campaign, status, or custom tags.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list threads from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels (e.g. ['q4-campaign', 'follow-up'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects with thread_id, subject, message count, labels, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="63dd9e2d-ef81-405c-b034-c031f0437334",
|
||||
description="List all conversation threads in an AgentMail inbox. Filter by labels for campaign tracking or status management.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_threads(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.list(inbox_id=inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_threads(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", (c if (c := response.count) is not None else len(threads))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetInboxThreadBlock(Block):
|
||||
"""
|
||||
Retrieve a single conversation thread from an AgentMail inbox.
|
||||
|
||||
Returns the thread with all its messages in chronological order.
|
||||
Use this to get the full conversation history for context when
|
||||
composing replies.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
thread_id: str = SchemaField(description="Unique identifier of the thread")
|
||||
messages: list[dict] = SchemaField(
|
||||
description="All messages in the thread, in chronological order"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete thread object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="42866290-1479-4153-83e7-550b703e9da2",
|
||||
description="Retrieve a conversation thread with all its messages. Use for getting full conversation context before replying.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[
|
||||
("thread_id", "test-thread"),
|
||||
("messages", []),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_thread": lambda *a, **kw: type(
|
||||
"Thread",
|
||||
(),
|
||||
{
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
"model_dump": lambda self: {
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_thread(credentials: APIKeyCredentials, inbox_id: str, thread_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.get(inbox_id=inbox_id, thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
thread = await self.get_thread(
|
||||
credentials, input_data.inbox_id, input_data.thread_id
|
||||
)
|
||||
messages = [m.model_dump() for m in thread.messages]
|
||||
result = thread.model_dump()
|
||||
result["messages"] = messages
|
||||
|
||||
yield "thread_id", thread.thread_id
|
||||
yield "messages", messages
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteInboxThreadBlock(Block):
|
||||
"""
|
||||
Permanently delete a conversation thread and all its messages from an inbox.
|
||||
|
||||
This removes the thread and every message within it. This action
|
||||
cannot be undone.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID to permanently delete")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the thread was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="18cd5f6f-4ff6-45da-8300-25a50ea7fb75",
|
||||
description="Permanently delete a conversation thread and all its messages. This action cannot be undone.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_thread": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_thread(
|
||||
credentials: APIKeyCredentials, inbox_id: str, thread_id: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.threads.delete(inbox_id=inbox_id, thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_thread(
|
||||
credentials, input_data.inbox_id, input_data.thread_id
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListOrgThreadsBlock(Block):
|
||||
"""
|
||||
List conversation threads across ALL inboxes in your organization.
|
||||
|
||||
Unlike per-inbox listing, this returns threads from every inbox.
|
||||
Ideal for building supervisor agents that monitor all conversations,
|
||||
analytics dashboards, or cross-agent routing workflows.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects from all inboxes in the organization"
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d7a0657b-58ab-48b2-898b-7bd94f44a708",
|
||||
description="List threads across ALL inboxes in your organization. Use for supervisor agents, dashboards, or cross-agent monitoring.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_org_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_org_threads(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.threads.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_org_threads(credentials, **params)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", (c if (c := response.count) is not None else len(threads))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetOrgThreadBlock(Block):
|
||||
"""
|
||||
Retrieve a single conversation thread by ID from anywhere in the organization.
|
||||
|
||||
Works without needing to know which inbox the thread belongs to.
|
||||
Returns the thread with all its messages in chronological order.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID to retrieve (works across all inboxes)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
thread_id: str = SchemaField(description="Unique identifier of the thread")
|
||||
messages: list[dict] = SchemaField(
|
||||
description="All messages in the thread, in chronological order"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete thread object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="39aaae31-3eb1-44c6-9e37-5a44a4529649",
|
||||
description="Retrieve a conversation thread by ID from anywhere in the organization, without needing the inbox ID.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[
|
||||
("thread_id", "test-thread"),
|
||||
("messages", []),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_org_thread": lambda *a, **kw: type(
|
||||
"Thread",
|
||||
(),
|
||||
{
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
"model_dump": lambda self: {
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_org_thread(credentials: APIKeyCredentials, thread_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.threads.get(thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
thread = await self.get_org_thread(credentials, input_data.thread_id)
|
||||
messages = [m.model_dump() for m in thread.messages]
|
||||
result = thread.model_dump()
|
||||
result["messages"] = messages
|
||||
|
||||
yield "thread_id", thread.thread_id
|
||||
yield "messages", messages
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -27,6 +27,7 @@ from backend.util.file import MediaFileType, store_media_file
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
NANO_BANANA_PRO = "google/nano-banana-pro"
|
||||
NANO_BANANA_2 = "google/nano-banana-2"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -77,7 +78,7 @@ class AIImageCustomizerBlock(Block):
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
default=GeminiImageModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
@@ -103,7 +104,7 @@ class AIImageCustomizerBlock(Block):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Generate and edit custom images using Google's Nano-Banana models from Gemini. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
@@ -111,7 +112,7 @@ class AIImageCustomizerBlock(Block):
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"model": GeminiImageModel.NANO_BANANA_2,
|
||||
"images": [],
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"output_format": OutputFormat.JPG,
|
||||
|
||||
@@ -115,6 +115,7 @@ class ImageGenModel(str, Enum):
|
||||
RECRAFT = "Recraft v3"
|
||||
SD3_5 = "Stable Diffusion 3.5 Medium"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
NANO_BANANA_2 = "Nano Banana 2"
|
||||
|
||||
|
||||
class AIImageGeneratorBlock(Block):
|
||||
@@ -131,7 +132,7 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
model: ImageGenModel = SchemaField(
|
||||
description="The AI model to use for image generation",
|
||||
default=ImageGenModel.SD3_5,
|
||||
default=ImageGenModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
size: ImageSize = SchemaField(
|
||||
@@ -165,7 +166,7 @@ class AIImageGeneratorBlock(Block):
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"prompt": "An octopus using a laptop in a snowy forest with 'AutoGPT' clearly visible on the screen",
|
||||
"model": ImageGenModel.RECRAFT,
|
||||
"model": ImageGenModel.NANO_BANANA_2,
|
||||
"size": ImageSize.SQUARE,
|
||||
"style": ImageStyle.REALISTIC,
|
||||
},
|
||||
@@ -179,7 +180,9 @@ class AIImageGeneratorBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
# Return a data URI directly so store_media_file doesn't need to download
|
||||
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||
"_run_client": lambda *args, **kwargs: (
|
||||
"data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -280,17 +283,24 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
return output
|
||||
|
||||
elif input_data.model == ImageGenModel.NANO_BANANA_PRO:
|
||||
# Use Nano Banana Pro (Google Gemini 3 Pro Image)
|
||||
elif input_data.model in (
|
||||
ImageGenModel.NANO_BANANA_PRO,
|
||||
ImageGenModel.NANO_BANANA_2,
|
||||
):
|
||||
# Use Nano Banana models (Google Gemini image variants)
|
||||
model_map = {
|
||||
ImageGenModel.NANO_BANANA_PRO: "google/nano-banana-pro",
|
||||
ImageGenModel.NANO_BANANA_2: "google/nano-banana-2",
|
||||
}
|
||||
input_params = {
|
||||
"prompt": modified_prompt,
|
||||
"aspect_ratio": SIZE_TO_NANO_BANANA_RATIO[input_data.size],
|
||||
"resolution": "2K", # Default to 2K for good quality/cost balance
|
||||
"resolution": "2K",
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high", # Most permissive
|
||||
"safety_filter_level": "block_only_high",
|
||||
}
|
||||
output = await self._run_client(
|
||||
credentials, "google/nano-banana-pro", input_params
|
||||
credentials, model_map[input_data.model], input_params
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
376
autogpt_platform/backend/backend/blocks/autopilot.py
Normal file
376
autogpt_platform/backend/backend/blocks/autopilot.py
Normal file
@@ -0,0 +1,376 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Block ID shared between autopilot.py and copilot prompting.py.
|
||||
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
|
||||
|
||||
class ToolCallEntry(TypedDict):
|
||||
"""A single tool invocation record from an autopilot execution."""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
input: Any
|
||||
output: Any | None
|
||||
success: bool | None
|
||||
|
||||
|
||||
class TokenUsage(TypedDict):
|
||||
"""Aggregated token counts from the autopilot stream."""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class AutoPilotBlock(Block):
|
||||
"""Execute tasks using AutoGPT AutoPilot with full access to platform tools.
|
||||
|
||||
The autopilot can manage agents, access workspace files, fetch web content,
|
||||
run blocks, and more. This block enables sub-agent patterns (autopilot calling
|
||||
autopilot) and scheduled autopilot execution via the agent executor.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for the AutoPilot block."""
|
||||
|
||||
prompt: str = SchemaField(
|
||||
description=(
|
||||
"The task or instruction for the autopilot to execute. "
|
||||
"The autopilot has access to platform tools like agent management, "
|
||||
"workspace files, web fetch, block execution, and more."
|
||||
),
|
||||
placeholder="Find my agents and list them",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
system_context: str = SchemaField(
|
||||
description=(
|
||||
"Optional additional context prepended to the prompt. "
|
||||
"Use this to constrain autopilot behavior, provide domain "
|
||||
"context, or set output format requirements."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID to continue an existing autopilot conversation. "
|
||||
"Leave empty to start a new session. "
|
||||
"Use the session_id output from a previous run to continue."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
max_recursion_depth: int = SchemaField(
|
||||
description=(
|
||||
"Maximum nesting depth when the autopilot calls this block "
|
||||
"recursively (sub-agent pattern). Prevents infinite loops."
|
||||
),
|
||||
default=3,
|
||||
ge=1,
|
||||
le=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# timeout_seconds removed: the SDK manages its own heartbeat-based
|
||||
# timeouts internally; wrapping with asyncio.timeout corrupts the
|
||||
# SDK's internal stream (see service.py CRITICAL comment).
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for the AutoPilot block."""
|
||||
|
||||
response: str = SchemaField(
|
||||
description="The final text response from the autopilot."
|
||||
)
|
||||
tool_calls: list[ToolCallEntry] = SchemaField(
|
||||
description=(
|
||||
"List of tools called during execution. Each entry has "
|
||||
"tool_call_id, tool_name, input, output, and success fields."
|
||||
),
|
||||
)
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Current turn messages (user prompt + assistant reply) as JSON. "
|
||||
"It can be used for logging or analysis."
|
||||
),
|
||||
)
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID for this conversation. "
|
||||
"Pass this back to continue the conversation in a future run."
|
||||
),
|
||||
)
|
||||
token_usage: TokenUsage = SchemaField(
|
||||
description=(
|
||||
"Token usage statistics: prompt_tokens, "
|
||||
"completion_tokens, total_tokens."
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id=AUTOPILOT_BLOCK_ID,
|
||||
description=(
|
||||
"Execute tasks using AutoGPT AutoPilot with full access to "
|
||||
"platform tools (agent management, workspace files, web fetch, "
|
||||
"block execution, and more). Enables sub-agent patterns and "
|
||||
"scheduled autopilot execution."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.AGENT},
|
||||
input_schema=AutoPilotBlock.Input,
|
||||
output_schema=AutoPilotBlock.Output,
|
||||
test_input={
|
||||
"prompt": "List my agents",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
},
|
||||
test_output=[
|
||||
("response", "You have 2 agents: Agent A and Agent B."),
|
||||
("tool_calls", []),
|
||||
(
|
||||
"conversation_history",
|
||||
'[{"role": "user", "content": "List my agents"}]',
|
||||
),
|
||||
("session_id", "test-session-id"),
|
||||
(
|
||||
"token_usage",
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_session": lambda *args, **kwargs: "test-session-id",
|
||||
"execute_copilot": lambda *args, **kwargs: (
|
||||
"You have 2 agents: Agent A and Agent B.",
|
||||
[],
|
||||
'[{"role": "user", "content": "List my agents"}]',
|
||||
"test-session-id",
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def create_session(self, user_id: str) -> str:
|
||||
"""Create a new chat session and return its ID (mockable for tests)."""
|
||||
from backend.copilot.model import create_chat_session
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
return session.session_id
|
||||
|
||||
async def execute_copilot(
|
||||
self,
|
||||
prompt: str,
|
||||
system_context: str,
|
||||
session_id: str,
|
||||
max_recursion_depth: int,
|
||||
user_id: str,
|
||||
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
|
||||
"""Invoke the copilot and collect all stream results.
|
||||
|
||||
Delegates to :func:`collect_copilot_response` — the shared helper that
|
||||
consumes ``stream_chat_completion_sdk`` without wrapping it in an
|
||||
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
|
||||
|
||||
Args:
|
||||
prompt: The user task/instruction.
|
||||
system_context: Optional context prepended to the prompt.
|
||||
session_id: Chat session to use.
|
||||
max_recursion_depth: Maximum allowed recursion nesting.
|
||||
user_id: Authenticated user ID.
|
||||
|
||||
Returns:
|
||||
A tuple of (response_text, tool_calls, history_json, session_id, usage).
|
||||
"""
|
||||
from backend.copilot.sdk.collect import collect_copilot_response
|
||||
|
||||
tokens = _check_recursion(max_recursion_depth)
|
||||
try:
|
||||
effective_prompt = prompt
|
||||
if system_context:
|
||||
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
|
||||
|
||||
result = await collect_copilot_response(
|
||||
session_id=session_id,
|
||||
message=effective_prompt,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Build a lightweight conversation summary from streamed data.
|
||||
turn_messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": effective_prompt},
|
||||
]
|
||||
if result.tool_calls:
|
||||
turn_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": result.response_text,
|
||||
"tool_calls": result.tool_calls,
|
||||
}
|
||||
)
|
||||
else:
|
||||
turn_messages.append(
|
||||
{"role": "assistant", "content": result.response_text}
|
||||
)
|
||||
history_json = json.dumps(turn_messages, default=str)
|
||||
|
||||
tool_calls: list[ToolCallEntry] = [
|
||||
{
|
||||
"tool_call_id": tc["tool_call_id"],
|
||||
"tool_name": tc["tool_name"],
|
||||
"input": tc["input"],
|
||||
"output": tc["output"],
|
||||
"success": tc["success"],
|
||||
}
|
||||
for tc in result.tool_calls
|
||||
]
|
||||
|
||||
usage: TokenUsage = {
|
||||
"prompt_tokens": result.prompt_tokens,
|
||||
"completion_tokens": result.completion_tokens,
|
||||
"total_tokens": result.total_tokens,
|
||||
}
|
||||
|
||||
return (
|
||||
result.response_text,
|
||||
tool_calls,
|
||||
history_json,
|
||||
session_id,
|
||||
usage,
|
||||
)
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Validate inputs, invoke the autopilot, and yield structured outputs.
|
||||
|
||||
Yields session_id even on failure so callers can inspect/resume the session.
|
||||
"""
|
||||
if not input_data.prompt.strip():
|
||||
yield "error", "Prompt cannot be empty."
|
||||
return
|
||||
|
||||
if not execution_context.user_id:
|
||||
yield "error", "Cannot run autopilot without an authenticated user."
|
||||
return
|
||||
|
||||
if input_data.max_recursion_depth < 1:
|
||||
yield "error", "max_recursion_depth must be at least 1."
|
||||
return
|
||||
|
||||
# Create session eagerly so the user always gets the session_id,
|
||||
# even if the downstream stream fails (avoids orphaned sessions).
|
||||
sid = input_data.session_id
|
||||
if not sid:
|
||||
sid = await self.create_session(execution_context.user_id)
|
||||
|
||||
# NOTE: No asyncio.timeout() here — the SDK manages its own
|
||||
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout
|
||||
# would cancel the task mid-flight, corrupting the SDK's internal
|
||||
# anyio memory stream (see service.py CRITICAL comment).
|
||||
try:
|
||||
response, tool_calls, history, _, usage = await self.execute_copilot(
|
||||
prompt=input_data.prompt,
|
||||
system_context=input_data.system_context,
|
||||
session_id=sid,
|
||||
max_recursion_depth=input_data.max_recursion_depth,
|
||||
user_id=execution_context.user_id,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
yield "tool_calls", tool_calls
|
||||
yield "conversation_history", history
|
||||
yield "session_id", sid
|
||||
yield "token_usage", usage
|
||||
except asyncio.CancelledError:
|
||||
yield "session_id", sid
|
||||
yield "error", "AutoPilot execution was cancelled."
|
||||
raise
|
||||
except Exception as exc:
|
||||
yield "session_id", sid
|
||||
yield "error", str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers – placed after the block class for top-down readability.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Task-scoped recursion depth counter & chain-wide limit.
|
||||
# contextvars are scoped to the current asyncio task, so concurrent
|
||||
# graph executions each get independent counters.
|
||||
_autopilot_recursion_depth: contextvars.ContextVar[int] = contextvars.ContextVar(
|
||||
"_autopilot_recursion_depth", default=0
|
||||
)
|
||||
_autopilot_recursion_limit: contextvars.ContextVar[int | None] = contextvars.ContextVar(
|
||||
"_autopilot_recursion_limit", default=None
|
||||
)
|
||||
|
||||
|
||||
def _check_recursion(
|
||||
max_depth: int,
|
||||
) -> tuple[contextvars.Token[int], contextvars.Token[int | None]]:
|
||||
"""Check and increment recursion depth.
|
||||
|
||||
Returns ContextVar tokens that must be passed to ``_reset_recursion``
|
||||
when the caller exits to restore the previous depth.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the current depth already meets or exceeds the limit.
|
||||
"""
|
||||
current = _autopilot_recursion_depth.get()
|
||||
inherited = _autopilot_recursion_limit.get()
|
||||
limit = max_depth if inherited is None else min(inherited, max_depth)
|
||||
if current >= limit:
|
||||
raise RuntimeError(
|
||||
f"AutoPilot recursion depth limit reached ({limit}). "
|
||||
"The autopilot has called itself too many times."
|
||||
)
|
||||
return (
|
||||
_autopilot_recursion_depth.set(current + 1),
|
||||
_autopilot_recursion_limit.set(limit),
|
||||
)
|
||||
|
||||
|
||||
def _reset_recursion(
|
||||
tokens: tuple[contextvars.Token[int], contextvars.Token[int | None]],
|
||||
) -> None:
|
||||
"""Restore recursion depth and limit to their previous values."""
|
||||
_autopilot_recursion_depth.reset(tokens[0])
|
||||
_autopilot_recursion_limit.reset(tokens[1])
|
||||
@@ -472,7 +472,7 @@ class AddToListBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
entries_added = input_data.entries.copy()
|
||||
if input_data.entry:
|
||||
if input_data.entry is not None:
|
||||
entries_added.append(input_data.entry)
|
||||
|
||||
updated_list = input_data.list.copy()
|
||||
|
||||
@@ -21,6 +21,7 @@ from backend.data.model import (
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import resolve_and_check_blocked
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -99,6 +100,8 @@ class SendEmailBlock(Block):
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
ALLOWED_SMTP_PORTS = {25, 465, 587, 2525}
|
||||
|
||||
@staticmethod
|
||||
def send_email(
|
||||
config: SMTPConfig,
|
||||
@@ -129,6 +132,17 @@ class SendEmailBlock(Block):
|
||||
self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# --- SSRF Protection ---
|
||||
smtp_port = input_data.config.smtp_port
|
||||
if smtp_port not in self.ALLOWED_SMTP_PORTS:
|
||||
yield "error", (
|
||||
f"SMTP port {smtp_port} is not allowed. "
|
||||
f"Allowed ports: {sorted(self.ALLOWED_SMTP_PORTS)}"
|
||||
)
|
||||
return
|
||||
|
||||
await resolve_and_check_blocked(input_data.config.smtp_server)
|
||||
|
||||
status = self.send_email(
|
||||
config=input_data.config,
|
||||
to_email=input_data.to_email,
|
||||
@@ -180,7 +194,19 @@ class SendEmailBlock(Block):
|
||||
"was rejected by the server. "
|
||||
"Please verify your account is authorized to send emails."
|
||||
)
|
||||
except smtplib.SMTPConnectError:
|
||||
yield "error", (
|
||||
f"Cannot connect to SMTP server '{input_data.config.smtp_server}' "
|
||||
f"on port {input_data.config.smtp_port}."
|
||||
)
|
||||
except smtplib.SMTPServerDisconnected:
|
||||
yield "error", (
|
||||
f"SMTP server '{input_data.config.smtp_server}' "
|
||||
"disconnected unexpectedly."
|
||||
)
|
||||
except smtplib.SMTPDataError as e:
|
||||
yield "error", f"Email data rejected by server: {str(e)}"
|
||||
except ValueError as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -34,17 +34,29 @@ TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class FluxKontextModelName(str, Enum):
|
||||
PRO = "Flux Kontext Pro"
|
||||
MAX = "Flux Kontext Max"
|
||||
class ImageEditorModel(str, Enum):
|
||||
FLUX_KONTEXT_PRO = "Flux Kontext Pro"
|
||||
FLUX_KONTEXT_MAX = "Flux Kontext Max"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
NANO_BANANA_2 = "Nano Banana 2"
|
||||
|
||||
@property
|
||||
def api_name(self) -> str:
|
||||
return f"black-forest-labs/flux-kontext-{self.name.lower()}"
|
||||
_map = {
|
||||
"FLUX_KONTEXT_PRO": "black-forest-labs/flux-kontext-pro",
|
||||
"FLUX_KONTEXT_MAX": "black-forest-labs/flux-kontext-max",
|
||||
"NANO_BANANA_PRO": "google/nano-banana-pro",
|
||||
"NANO_BANANA_2": "google/nano-banana-2",
|
||||
}
|
||||
return _map[self.name]
|
||||
|
||||
|
||||
# Keep old name as alias for backwards compatibility
|
||||
FluxKontextModelName = ImageEditorModel
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -69,7 +81,7 @@ class AIImageEditorBlock(Block):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Flux Kontext models",
|
||||
description="Replicate API key with permissions for Flux Kontext and Nano Banana models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="Text instruction describing the desired edit",
|
||||
@@ -87,14 +99,14 @@ class AIImageEditorBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
seed: Optional[int] = SchemaField(
|
||||
description="Random seed. Set for reproducible generation",
|
||||
description="Random seed. Set for reproducible generation (Flux Kontext only; ignored by Nano Banana models)",
|
||||
default=None,
|
||||
title="Seed",
|
||||
advanced=True,
|
||||
)
|
||||
model: FluxKontextModelName = SchemaField(
|
||||
model: ImageEditorModel = SchemaField(
|
||||
description="Model variant to use",
|
||||
default=FluxKontextModelName.PRO,
|
||||
default=ImageEditorModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
|
||||
@@ -107,7 +119,7 @@ class AIImageEditorBlock(Block):
|
||||
super().__init__(
|
||||
id="3fd9c73d-4370-4925-a1ff-1b86b99fabfa",
|
||||
description=(
|
||||
"Edit images using BlackForest Labs' Flux Kontext models. Provide a prompt "
|
||||
"Edit images using Flux Kontext or Google Nano Banana models. Provide a prompt "
|
||||
"and optional reference image to generate a modified image."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
@@ -118,7 +130,7 @@ class AIImageEditorBlock(Block):
|
||||
"input_image": "data:image/png;base64,MQ==",
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"seed": None,
|
||||
"model": FluxKontextModelName.PRO,
|
||||
"model": ImageEditorModel.NANO_BANANA_2,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
@@ -127,7 +139,9 @@ class AIImageEditorBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
|
||||
"run_model": lambda *args, **kwargs: (
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
@@ -142,7 +156,7 @@ class AIImageEditorBlock(Block):
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.api_name,
|
||||
model=input_data.model,
|
||||
prompt=input_data.prompt,
|
||||
input_image_b64=(
|
||||
await store_media_file(
|
||||
@@ -169,7 +183,7 @@ class AIImageEditorBlock(Block):
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
model: ImageEditorModel,
|
||||
prompt: str,
|
||||
input_image_b64: Optional[str],
|
||||
aspect_ratio: str,
|
||||
@@ -178,12 +192,29 @@ class AIImageEditorBlock(Block):
|
||||
graph_exec_id: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
model_name = model.api_name
|
||||
|
||||
is_nano_banana = model in (
|
||||
ImageEditorModel.NANO_BANANA_PRO,
|
||||
ImageEditorModel.NANO_BANANA_2,
|
||||
)
|
||||
if is_nano_banana:
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high",
|
||||
}
|
||||
# NB API expects "image_input" as a list, unlike Flux's single "input_image"
|
||||
if input_image_b64:
|
||||
input_params["image_input"] = [input_image_b64]
|
||||
else:
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
try:
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
|
||||
@@ -211,7 +211,7 @@ class AgentOutputBlock(Block):
|
||||
if input_data.format:
|
||||
try:
|
||||
formatter = TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(
|
||||
yield "output", await formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1276,8 +1276,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = fmt.format_string(input_data.sys_prompt, values)
|
||||
input_data.prompt = await fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = await fmt.format_string(
|
||||
input_data.sys_prompt, values
|
||||
)
|
||||
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
|
||||
@@ -1050,8 +1050,10 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
|
||||
input_data.prompt = await llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = await llm.fmt.format_string(
|
||||
input_data.sys_prompt, values
|
||||
)
|
||||
|
||||
if input_data.sys_prompt and not any(
|
||||
p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
|
||||
223
autogpt_platform/backend/backend/blocks/test/test_autopilot.py
Normal file
223
autogpt_platform/backend/backend/blocks/test/test_autopilot.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Tests for AutoPilotBlock: recursion guard, streaming, validation, and error paths."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AUTOPILOT_BLOCK_ID,
|
||||
AutoPilotBlock,
|
||||
_autopilot_recursion_depth,
|
||||
_autopilot_recursion_limit,
|
||||
_check_recursion,
|
||||
_reset_recursion,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
|
||||
def _make_context(user_id: str = "test-user-123") -> ExecutionContext:
|
||||
"""Helper to build an ExecutionContext for tests."""
|
||||
return ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="graph-1",
|
||||
graph_exec_id="gexec-1",
|
||||
graph_version=1,
|
||||
node_id="node-1",
|
||||
node_exec_id="nexec-1",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recursion guard unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecursion:
|
||||
"""Unit tests for _check_recursion / _reset_recursion."""
|
||||
|
||||
def test_first_call_increments_depth(self):
|
||||
tokens = _check_recursion(3)
|
||||
try:
|
||||
assert _autopilot_recursion_depth.get() == 1
|
||||
assert _autopilot_recursion_limit.get() == 3
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
|
||||
def test_reset_restores_previous_values(self):
|
||||
assert _autopilot_recursion_depth.get() == 0
|
||||
assert _autopilot_recursion_limit.get() is None
|
||||
tokens = _check_recursion(5)
|
||||
_reset_recursion(tokens)
|
||||
assert _autopilot_recursion_depth.get() == 0
|
||||
assert _autopilot_recursion_limit.get() is None
|
||||
|
||||
def test_exceeding_limit_raises(self):
|
||||
t1 = _check_recursion(2)
|
||||
try:
|
||||
t2 = _check_recursion(2)
|
||||
try:
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
_check_recursion(2)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
def test_nested_calls_respect_inherited_limit(self):
|
||||
"""Inner call with higher max_depth still respects outer limit."""
|
||||
t1 = _check_recursion(2) # sets limit=2
|
||||
try:
|
||||
t2 = _check_recursion(10) # inner wants 10, but inherited is 2
|
||||
try:
|
||||
# depth is now 2, limit is min(10, 2) = 2 → should raise
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
_check_recursion(10)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
def test_limit_of_one_blocks_immediately_on_second_call(self):
|
||||
t1 = _check_recursion(1)
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
_check_recursion(1)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run() validation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunValidation:
|
||||
"""Tests for input validation in AutoPilotBlock.run()."""
|
||||
|
||||
@pytest.fixture
|
||||
def block(self):
|
||||
return AutoPilotBlock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt_yields_error(self, block):
|
||||
block.Input # ensure schema is accessible
|
||||
input_data = block.Input(prompt=" ", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
assert outputs.get("error") == "Prompt cannot be empty."
|
||||
assert "response" not in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_user_id_yields_error(self, block):
|
||||
input_data = block.Input(prompt="hello", max_recursion_depth=3)
|
||||
ctx = _make_context(user_id="")
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
assert "authenticated user" in outputs.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_run_yields_all_outputs(self, block):
|
||||
"""With execute_copilot mocked, run() should yield all 5 success outputs."""
|
||||
mock_result = (
|
||||
"Hello world",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"sess-abc",
|
||||
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock(return_value="sess-abc")
|
||||
|
||||
input_data = block.Input(prompt="hi", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["response"] == "Hello world"
|
||||
assert outputs["tool_calls"] == []
|
||||
assert outputs["session_id"] == "sess-abc"
|
||||
assert outputs["token_usage"]["total_tokens"] == 15
|
||||
assert "error" not in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_yields_error(self, block):
|
||||
"""On unexpected failure, run() should yield an error output."""
|
||||
block.execute_copilot = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
block.create_session = AsyncMock(return_value="sess-fail")
|
||||
|
||||
input_data = block.Input(prompt="do something", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["session_id"] == "sess-fail"
|
||||
assert "boom" in outputs.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_error_yields_error_and_reraises(self, block):
|
||||
"""CancelledError should yield error, then re-raise."""
|
||||
block.execute_copilot = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
block.create_session = AsyncMock(return_value="sess-cancel")
|
||||
|
||||
input_data = block.Input(prompt="do something", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["session_id"] == "sess-cancel"
|
||||
assert "cancelled" in outputs.get("error", "").lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_existing_session_id_skips_create(self, block):
|
||||
"""When session_id is provided, create_session should not be called."""
|
||||
mock_result = (
|
||||
"ok",
|
||||
[],
|
||||
"[]",
|
||||
"existing-sid",
|
||||
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock()
|
||||
|
||||
input_data = block.Input(
|
||||
prompt="test", session_id="existing-sid", max_recursion_depth=3
|
||||
)
|
||||
ctx = _make_context()
|
||||
async for _ in block.run(input_data, execution_context=ctx):
|
||||
pass
|
||||
|
||||
block.create_session.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block registration / ID tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlockRegistration:
|
||||
def test_block_id_matches_constant(self):
|
||||
block = AutoPilotBlock()
|
||||
assert block.id == AUTOPILOT_BLOCK_ID
|
||||
|
||||
def test_max_recursion_depth_has_upper_bound(self):
|
||||
"""Schema should enforce le=10."""
|
||||
schema = AutoPilotBlock.Input.model_json_schema()
|
||||
max_rec = schema["properties"]["max_recursion_depth"]
|
||||
assert (
|
||||
max_rec.get("maximum") == 10 or max_rec.get("exclusiveMaximum", 999) <= 11
|
||||
)
|
||||
|
||||
def test_output_schema_has_no_duplicate_error_field(self):
|
||||
"""Output should inherit error from BlockSchemaOutput, not redefine it."""
|
||||
# The field should exist (inherited) but there should be no explicit
|
||||
# redefinition. We verify by checking the class __annotations__ directly.
|
||||
assert "error" not in AutoPilotBlock.Output.__annotations__
|
||||
@@ -290,7 +290,9 @@ class FillTextTemplateBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
formatter = text.TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(input_data.format, input_data.values)
|
||||
yield "output", await formatter.format_string(
|
||||
input_data.format, input_data.values
|
||||
)
|
||||
|
||||
|
||||
class CombineTextsBlock(Block):
|
||||
|
||||
@@ -115,10 +115,22 @@ class ChatConfig(BaseSettings):
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
use_openrouter: bool = Field(
|
||||
default=True,
|
||||
description="Enable routing API calls through the OpenRouter proxy. "
|
||||
"The actual decision also requires ``api_key`` and ``base_url`` — "
|
||||
"use the ``openrouter_active`` property for the final answer.",
|
||||
)
|
||||
use_claude_code_subscription: bool = Field(
|
||||
default=False,
|
||||
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
|
||||
)
|
||||
test_mode: bool = Field(
|
||||
default=False,
|
||||
description="Use dummy service instead of real LLM calls. "
|
||||
"Send __test_transient_error__, __test_fatal_error__, or "
|
||||
"__test_slow_response__ to trigger specific scenarios.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
@@ -136,7 +148,7 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||
default=420, # 7 min safety net — allows headroom for compaction retries
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
@@ -146,6 +158,21 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
|
||||
)
|
||||
|
||||
@property
|
||||
def openrouter_active(self) -> bool:
|
||||
"""True when OpenRouter is enabled AND credentials are usable.
|
||||
|
||||
Single source of truth for "will the SDK route through OpenRouter?".
|
||||
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
|
||||
present — mirrors the fallback logic in ``_build_sdk_env``.
|
||||
"""
|
||||
if not self.use_openrouter:
|
||||
return False
|
||||
base = (self.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
return bool(self.api_key and base and base.startswith("http"))
|
||||
|
||||
@property
|
||||
def e2b_active(self) -> bool:
|
||||
"""True when E2B is enabled and the API key is present.
|
||||
@@ -168,15 +195,6 @@ class ChatConfig(BaseSettings):
|
||||
"""
|
||||
return self.e2b_api_key if self.e2b_active else None
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
def get_use_e2b_sandbox(cls, v):
|
||||
"""Get use_e2b_sandbox from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("e2b_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
@@ -219,26 +237,6 @@ class ChatConfig(BaseSettings):
|
||||
v = OPENROUTER_BASE_URL
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||
# Check environment variable - default to True if not set
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("use_claude_code_subscription", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_code_subscription(cls, v):
|
||||
"""Get use_claude_code_subscription from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return False if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
@@ -248,6 +246,7 @@ class ChatConfig(BaseSettings):
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
env_prefix = "CHAT_"
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore" # Ignore extra environment variables
|
||||
|
||||
@@ -6,19 +6,70 @@ from .config import ChatConfig
|
||||
|
||||
# Env vars that the ChatConfig validators read — must be cleared so they don't
|
||||
# override the explicit constructor values we pass in each test.
|
||||
_E2B_ENV_VARS = (
|
||||
_ENV_VARS_TO_CLEAR = (
|
||||
"CHAT_USE_E2B_SANDBOX",
|
||||
"CHAT_E2B_API_KEY",
|
||||
"E2B_API_KEY",
|
||||
"CHAT_USE_OPENROUTER",
|
||||
"CHAT_API_KEY",
|
||||
"OPEN_ROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"CHAT_BASE_URL",
|
||||
"OPENROUTER_BASE_URL",
|
||||
"OPENAI_BASE_URL",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _E2B_ENV_VARS:
|
||||
def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _ENV_VARS_TO_CLEAR:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestOpenrouterActive:
|
||||
"""Tests for the openrouter_active property."""
|
||||
|
||||
def test_enabled_with_credentials_returns_true(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is True
|
||||
|
||||
def test_enabled_but_missing_api_key_returns_false(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key=None,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
def test_disabled_returns_false_despite_credentials(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
def test_strips_v1_suffix_and_still_valid(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is True
|
||||
|
||||
def test_invalid_base_url_returns_false(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="not-a-url",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
|
||||
class TestE2BActive:
|
||||
"""Tests for the e2b_active property — single source of truth for E2B usage."""
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
# The hex suffix makes accidental LLM generation of these strings virtually
|
||||
# impossible, avoiding false-positive marker detection in normal conversation.
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX = (
|
||||
"[__COPILOT_RETRYABLE_ERROR_a9c2__]" # ErrorCard + retry
|
||||
)
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||
@@ -35,3 +38,24 @@ def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||
Format: "{node_id}:{random_hex}" → returns "{node_id}".
|
||||
"""
|
||||
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transient Anthropic API error detection
|
||||
# ---------------------------------------------------------------------------
|
||||
# Patterns in error text that indicate a transient Anthropic API error
|
||||
# (ECONNRESET / dropped TCP connection) which is retryable.
|
||||
_TRANSIENT_ERROR_PATTERNS = (
|
||||
"socket connection was closed unexpectedly",
|
||||
"ECONNRESET",
|
||||
"connection was forcibly closed",
|
||||
"network socket disconnected",
|
||||
)
|
||||
|
||||
FRIENDLY_TRANSIENT_MSG = "Anthropic connection interrupted — please retry"
|
||||
|
||||
|
||||
def is_transient_api_error(error_text: str) -> bool:
|
||||
"""Return True if *error_text* matches a known transient Anthropic API error."""
|
||||
lower = error_text.lower()
|
||||
return any(pat.lower() in lower for pat in _TRANSIENT_ERROR_PATTERNS)
|
||||
|
||||
@@ -17,8 +17,17 @@ from backend.util.workspace import WorkspaceManager
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
# Allowed base directory for the Read tool.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
# Allowed base directory for the Read tool. Public so service.py can use it
|
||||
# for sweep operations without depending on a private implementation detail.
|
||||
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
|
||||
# _projects_base() function.
|
||||
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
|
||||
|
||||
# Compiled UUID pattern for validating conversation directory names.
|
||||
# Kept as a module-level constant so the security-relevant pattern is easy
|
||||
# to audit in one place and avoids recompilation on every call.
|
||||
_UUID_RE = re.compile(r"^[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}$", re.IGNORECASE)
|
||||
|
||||
# Encoded project-directory name for the current session (e.g.
|
||||
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||
@@ -35,11 +44,20 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does."""
|
||||
def encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The Claude CLI encodes the absolute cwd as a directory name by replacing
|
||||
every non-alphanumeric character with ``-``. For example
|
||||
``/tmp/copilot-abc`` becomes ``-tmp-copilot-abc``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
# Keep the private alias for internal callers (backwards compat).
|
||||
_encode_cwd_for_cli = encode_cwd_for_cli
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
@@ -100,7 +118,9 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
|
||||
The SDK nests tool-results under a conversation UUID directory;
|
||||
the UUID segment is validated with ``_UUID_RE``.
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
@@ -119,10 +139,22 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
if resolved == tool_results_dir or resolved.startswith(
|
||||
tool_results_dir + os.sep
|
||||
):
|
||||
return True
|
||||
project_dir = os.path.realpath(os.path.join(SDK_PROJECTS_DIR, encoded))
|
||||
# Defence-in-depth: ensure project_dir didn't escape the base.
|
||||
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
|
||||
return False
|
||||
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
|
||||
# The SDK always creates a conversation UUID directory between
|
||||
# the project dir and tool-results/.
|
||||
if resolved.startswith(project_dir + os.sep):
|
||||
relative = resolved[len(project_dir) + 1 :]
|
||||
parts = relative.split(os.sep)
|
||||
# Require exactly: [<uuid>, "tool-results", <file>, ...]
|
||||
if (
|
||||
len(parts) >= 3
|
||||
and _UUID_RE.match(parts[0])
|
||||
and parts[1] == "tool-results"
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -9,7 +9,7 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import (
|
||||
_SDK_PROJECTS_DIR,
|
||||
SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
@@ -104,11 +104,13 @@ def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_dir():
|
||||
"""Files under the tool-results directory for the current project are allowed."""
|
||||
def test_is_allowed_local_path_tool_results_with_uuid():
|
||||
"""Files under <encoded-cwd>/<uuid>/tool-results/ are allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
path = os.path.join(tool_results_dir, "output.txt")
|
||||
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results", "output.txt"
|
||||
)
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -117,10 +119,22 @@ def test_is_allowed_local_path_tool_results_dir():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
|
||||
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
path = os.path.join(SDK_PROJECTS_DIR, encoded, "tool-results", "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
sibling_path = os.path.join(SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -129,6 +143,21 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
|
||||
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
uuid_str = "12345678-1234-5678-9abc-def012345678"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, uuid_str, "not-tool-results", "output.txt"
|
||||
)
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -16,6 +16,7 @@ from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
@@ -246,17 +247,25 @@ class CoPilotProcessor:
|
||||
# Choose service based on LaunchDarkly flag.
|
||||
# Claude Code subscription forces SDK mode (CLI subprocess auth).
|
||||
config = ChatConfig()
|
||||
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
if config.test_mode:
|
||||
stream_fn = stream_chat_completion_dummy
|
||||
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
|
||||
else:
|
||||
use_sdk = (
|
||||
config.use_claude_code_subscription
|
||||
or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
|
||||
173
autogpt_platform/backend/backend/copilot/integration_creds.py
Normal file
173
autogpt_platform/backend/backend/copilot/integration_creds.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Integration credential lookup with per-process TTL cache.
|
||||
|
||||
Provides token retrieval for connected integrations so that copilot tools
|
||||
(e.g. bash_exec) can inject auth tokens into the execution environment without
|
||||
hitting the database on every command.
|
||||
|
||||
Cache semantics (handled automatically by TTLCache):
|
||||
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
|
||||
for users who have credentials and are running many bash commands.
|
||||
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
|
||||
on every E2B command for users who haven't connected an account yet, while
|
||||
still picking up a newly-connected account within one minute.
|
||||
|
||||
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
|
||||
least-recently-used entry when the limit is reached.
|
||||
|
||||
Multi-worker note: both caches are in-process only. Each worker/replica
|
||||
maintains its own independent cache, so a credential fetch may be duplicated
|
||||
across processes. This is acceptable for the current goal (reduce DB hits per
|
||||
session per-process), but if cache efficiency across replicas becomes important
|
||||
a shared cache (e.g. Redis) should be used instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from backend.copilot.providers import SUPPORTED_PROVIDERS
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
register_creds_changed_hook,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Derived from the single SUPPORTED_PROVIDERS registry for backward compat.
|
||||
PROVIDER_ENV_VARS: dict[str, list[str]] = {
|
||||
slug: entry["env_vars"] for slug, entry in SUPPORTED_PROVIDERS.items()
|
||||
}
|
||||
|
||||
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
|
||||
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
|
||||
_CACHE_MAX_SIZE = 10_000
|
||||
|
||||
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
|
||||
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
|
||||
# because all callers (get_provider_token, invalidate_user_provider_cache) run
|
||||
# exclusively on the asyncio event loop. There are no await points between a
|
||||
# cache read and its corresponding write within any function, so no concurrent
|
||||
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
|
||||
# this path, a threading.RLock should be wrapped around these caches.
|
||||
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
|
||||
)
|
||||
# Separate cache for "no credentials" results with a shorter TTL.
|
||||
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
# entries. This avoids a lazy import inside creds_manager and eliminates the
|
||||
# circular-import risk.
|
||||
try:
|
||||
register_creds_changed_hook(invalidate_user_provider_cache)
|
||||
except RuntimeError:
|
||||
# Hook already registered (e.g. module re-import in tests).
|
||||
pass
|
||||
|
||||
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
|
||||
# on every cache-miss call to get_provider_token().
|
||||
_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
"""Return the user's access token for *provider*, or ``None`` if not connected.
|
||||
|
||||
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
|
||||
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
|
||||
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
|
||||
command for users who haven't connected yet, while still picking up a
|
||||
newly-connected account within one minute.
|
||||
"""
|
||||
cache_key = (user_id, provider)
|
||||
|
||||
if cache_key in _null_cache:
|
||||
return None
|
||||
if cached := _token_cache.get(cache_key):
|
||||
return cached
|
||||
|
||||
manager = _manager
|
||||
try:
|
||||
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch %s credentials for user %s",
|
||||
provider,
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
|
||||
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
|
||||
# full git access, while a public-data-only token lacks push/pull permission.
|
||||
# lock=False — background injection; not worth a distributed lock acquisition.
|
||||
oauth2_creds = sorted(
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
fresh = await manager.refresh_if_needed(
|
||||
user_id, cast(OAuth2Credentials, creds), lock=False
|
||||
)
|
||||
token = fresh.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh %s OAuth token for user %s; "
|
||||
"discarding stale token to force re-auth",
|
||||
provider,
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
# Do NOT fall back to the stale token — it is likely expired
|
||||
# or revoked. Returning None forces the caller to re-auth,
|
||||
# preventing the LLM from receiving a non-functional token.
|
||||
continue
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# Pass 2: fall back to API key (no expiry, no refresh needed).
|
||||
for creds in creds_list:
|
||||
if creds.type == "api_key":
|
||||
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
"""Return env vars for all providers the user has connected.
|
||||
|
||||
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
|
||||
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
|
||||
Only providers with a stored credential contribute entries.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
for provider, var_names in PROVIDER_ENV_VARS.items():
|
||||
token = await get_provider_token(user_id, provider)
|
||||
if token:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
@@ -0,0 +1,195 @@
|
||||
"""Tests for integration_creds — TTL cache and token lookup paths."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
get_provider_token,
|
||||
invalidate_user_provider_cache,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
|
||||
_USER = "user-integration-creds-test"
|
||||
_PROVIDER = "github"
|
||||
|
||||
|
||||
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="creds-api-key",
|
||||
provider=_PROVIDER,
|
||||
api_key=SecretStr(key),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id="creds-oauth2",
|
||||
provider=_PROVIDER,
|
||||
title="Test OAuth",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr("test-refresh"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
def test_removes_token_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_token_cache[key] = "tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _token_cache
|
||||
|
||||
def test_removes_null_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_null_cache[key] = True
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _null_cache
|
||||
|
||||
def test_noop_when_key_not_cached(self):
|
||||
# Should not raise even when there is no cache entry.
|
||||
invalidate_user_provider_cache("no-such-user", _PROVIDER)
|
||||
|
||||
def test_only_removes_targeted_key(self):
|
||||
other_key = ("other-user", _PROVIDER)
|
||||
_token_cache[other_key] = "other-tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_cached_token_without_db_hit(self):
|
||||
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "cached-tok"
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_none_for_null_cached_provider(self):
|
||||
_null_cache[(_USER, _PROVIDER)] = True
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_api_key_creds_returned_and_cached(self):
|
||||
api_creds = _make_api_key_creds("my-api-key")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "my-api-key"
|
||||
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_preferred_over_api_key(self):
|
||||
oauth_creds = _make_oauth2_creds("oauth-tok")
|
||||
api_creds = _make_api_key_creds("api-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[api_creds, oauth_creds]
|
||||
)
|
||||
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_returns_none(self):
|
||||
"""On refresh failure, return None instead of caching a stale token."""
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
# Stale tokens must NOT be returned — forces re-auth.
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
assert _null_cache.get((_USER, _PROVIDER)) is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_db_exception_returns_none_without_caching(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
side_effect=RuntimeError("db down")
|
||||
)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
# DB errors are not cached — next call will retry
|
||||
assert (_USER, _PROVIDER) not in _token_cache
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
|
||||
"""Verify the TTL constants are set correctly for each cache."""
|
||||
assert _null_cache.ttl == _NULL_CACHE_TTL
|
||||
assert _token_cache.ttl == _TOKEN_CACHE_TTL
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
_token_cache[(_USER, "github")] = "gh-tok"
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
for var in PROVIDER_ENV_VARS["github"]:
|
||||
assert result[var] == "gh-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_dict_when_no_credentials(self):
|
||||
_null_cache[(_USER, "github")] = True
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
assert result == {}
|
||||
@@ -6,10 +6,11 @@ handling the distinction between:
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.blocks.autopilot import AUTOPILOT_BLOCK_ID
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
_SHARED_TOOL_NOTES = f"""\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
@@ -81,18 +82,53 @@ that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{
|
||||
"files": [{
|
||||
{{
|
||||
"files": [{{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}]
|
||||
}
|
||||
}}]
|
||||
}}
|
||||
```
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
|
||||
### Delegating to another autopilot (sub-autopilot pattern)
|
||||
Use the **AutoPilotBlock** (`run_block` with block_id
|
||||
`{AUTOPILOT_BLOCK_ID}`) to delegate a task to a fresh
|
||||
autopilot instance. The sub-autopilot has its own full tool set and can
|
||||
perform multi-step work autonomously.
|
||||
|
||||
- **Input**: `prompt` (required) — the task description.
|
||||
Optional: `system_context` to constrain behavior, `session_id` to
|
||||
continue a previous conversation, `max_recursion_depth` (default 3).
|
||||
- **Output**: `response` (text), `tool_calls` (list), `session_id`
|
||||
(for continuation), `conversation_history`, `token_usage`.
|
||||
|
||||
Use this when a task is complex enough to benefit from a separate
|
||||
autopilot context, e.g. "research X and write a report" while the
|
||||
parent autopilot handles orchestration.
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
- For operations that need broader access (e.g. private org repos, GitHub
|
||||
Actions), pass the required scopes: e.g.
|
||||
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -105,6 +141,7 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
extra_notes: str = "",
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
@@ -119,6 +156,7 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
extra_notes: Environment-specific notes appended after shared notes
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
@@ -152,12 +190,23 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
### SDK tool-result files
|
||||
When tool outputs are large, the SDK truncates them and saves the full output to
|
||||
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
|
||||
always use `read_file` or `Read` (NOT `read_workspace_file`).
|
||||
`read_workspace_file` reads from cloud workspace storage, where SDK
|
||||
tool-results are NOT stored.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
"""Local ephemeral storage (files lost between turns).
|
||||
|
||||
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
|
||||
like gh will not work — no integration env-var notes are included.
|
||||
"""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
@@ -175,7 +224,11 @@ def _get_local_storage_supplement(cwd: str) -> str:
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
"""Cloud persistent sandbox (files survive across turns in session).
|
||||
|
||||
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
|
||||
injected per command in bash_exec — include the CLI guidance notes.
|
||||
"""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
@@ -190,6 +243,7 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
extra_notes=_E2B_TOOL_NOTES,
|
||||
)
|
||||
|
||||
|
||||
|
||||
63
autogpt_platform/backend/backend/copilot/providers.py
Normal file
63
autogpt_platform/backend/backend/copilot/providers.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Single source of truth for copilot-supported integration providers.
|
||||
|
||||
Both :mod:`~backend.copilot.integration_creds` (env-var injection) and
|
||||
:mod:`~backend.copilot.tools.connect_integration` (UI setup card) import from
|
||||
here, eliminating the risk of the two registries drifting out of sync.
|
||||
"""
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class ProviderEntry(TypedDict):
|
||||
"""Metadata for a supported integration provider.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable display name (e.g. "GitHub").
|
||||
env_vars: Environment variable names injected when the provider is
|
||||
connected (e.g. ``["GH_TOKEN", "GITHUB_TOKEN"]``).
|
||||
default_scopes: Default OAuth scopes requested when the agent does not
|
||||
specify any.
|
||||
"""
|
||||
|
||||
name: str
|
||||
env_vars: list[str]
|
||||
default_scopes: list[str]
|
||||
|
||||
|
||||
def _is_github_oauth_configured() -> bool:
|
||||
"""Return True if GitHub OAuth env vars are set.
|
||||
|
||||
Uses a lazy import to avoid triggering ``Secrets()`` during module import,
|
||||
which can fail in environments where secrets are not yet loaded (e.g. tests,
|
||||
CLI tooling).
|
||||
"""
|
||||
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
return GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
|
||||
# -- Registry ----------------------------------------------------------------
|
||||
# Add new providers here. Both env-var injection and the setup-card tool read
|
||||
# from this single registry.
|
||||
|
||||
SUPPORTED_PROVIDERS: dict[str, ProviderEntry] = {
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"env_vars": ["GH_TOKEN", "GITHUB_TOKEN"],
|
||||
"default_scopes": ["repo"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_provider_auth_types(provider: str) -> list[str]:
|
||||
"""Return the supported credential types for *provider* at runtime.
|
||||
|
||||
OAuth types are only offered when the corresponding OAuth client env vars
|
||||
are configured.
|
||||
"""
|
||||
if provider == "github":
|
||||
if _is_github_oauth_configured():
|
||||
return ["api_key", "oauth2"]
|
||||
return ["api_key"]
|
||||
# Default for unknown/future providers — API key only.
|
||||
return ["api_key"]
|
||||
@@ -43,6 +43,7 @@ class ResponseType(str, Enum):
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
HEARTBEAT = "heartbeat"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
@@ -263,3 +264,19 @@ class StreamHeartbeat(StreamBaseResponse):
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE comment format to keep connection alive."""
|
||||
return ": heartbeat\n\n"
|
||||
|
||||
|
||||
class StreamStatus(StreamBaseResponse):
|
||||
"""Transient status notification shown to the user during long operations.
|
||||
|
||||
Used to provide feedback when the backend performs behind-the-scenes work
|
||||
(e.g., compacting conversation context on a retry) that would otherwise
|
||||
leave the user staring at an unexplained pause.
|
||||
|
||||
Sent as a proper ``data:`` event so the frontend can display it to the
|
||||
user. The AI SDK stream parser gracefully skips unknown chunk types
|
||||
(logs a console warning), so this does not break the stream.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.STATUS
|
||||
message: str = Field(..., description="Human-readable status message")
|
||||
|
||||
@@ -19,9 +19,19 @@ least invasive way to break the cycle while keeping module-level constants
|
||||
intact.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
# Static imports for type checkers so they can resolve __all__ entries
|
||||
# without executing the lazy-import machinery at runtime.
|
||||
if TYPE_CHECKING:
|
||||
from .collect import CopilotResult as CopilotResult
|
||||
from .collect import collect_copilot_response as collect_copilot_response
|
||||
from .service import stream_chat_completion_sdk as stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server as create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"CopilotResult",
|
||||
"collect_copilot_response",
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
@@ -29,6 +39,8 @@ __all__ = [
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"CopilotResult": (".collect", "CopilotResult"),
|
||||
"collect_copilot_response": (".collect", "collect_copilot_response"),
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
@@ -143,6 +143,71 @@ To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||
tool_arguments.
|
||||
6. Output: `result` (the tool's return value) and `error` (error message)
|
||||
|
||||
### Using SmartDecisionMakerBlock (AI Orchestrator with Agent Mode)
|
||||
|
||||
To create an agent where AI autonomously decides which tools or sub-agents to
|
||||
call in a loop until the task is complete:
|
||||
1. Create a `SmartDecisionMakerBlock` node
|
||||
(ID: `3b191d9f-356f-482d-8238-ba04b6d18381`)
|
||||
2. Set `input_default`:
|
||||
- `agent_mode_max_iterations`: Choose based on task complexity:
|
||||
- `1` for single-step tool calls (AI picks one tool, calls it, done)
|
||||
- `3`–`10` for multi-step tasks (AI calls tools iteratively)
|
||||
- `-1` for open-ended orchestration (AI loops until it decides it's done).
|
||||
**Use with caution** — prefer bounded iterations (3–10) unless
|
||||
genuinely needed, as unbounded loops risk runaway cost and execution.
|
||||
Do NOT use `0` (traditional mode) — it requires complex external
|
||||
conversation-history loop wiring that the agent generator does not
|
||||
produce.
|
||||
- `conversation_compaction`: `true` (recommended to avoid context overflow)
|
||||
- `retry`: Number of retries on tool-call failure (default `3`).
|
||||
Set to `0` to disable retries.
|
||||
- `multiple_tool_calls`: Whether the AI can invoke multiple tools in a
|
||||
single turn (default `false`). Enable when tools are independent and
|
||||
can run concurrently.
|
||||
- Optional: `sys_prompt` for extra LLM context about how to orchestrate
|
||||
3. Wire the `prompt` input from an `AgentInputBlock` (the user's task)
|
||||
4. Create downstream tool blocks — regular blocks **or** `AgentExecutorBlock`
|
||||
nodes that call sub-agents
|
||||
5. Link each tool to the SmartDecisionMaker: set `source_name: "tools"` on
|
||||
the SmartDecisionMaker side and `sink_name: <input_field>` on each tool
|
||||
block's input. Create one link per input field the tool needs.
|
||||
6. Wire the `finished` output to an `AgentOutputBlock` for the final result
|
||||
7. Credentials (LLM API key) are configured by the user in the platform UI
|
||||
after saving — do NOT require them upfront
|
||||
|
||||
**Example — Orchestrator calling two sub-agents:**
|
||||
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
|
||||
- Node 2: `SmartDecisionMakerBlock` (input_default:
|
||||
`{"agent_mode_max_iterations": 10, "conversation_compaction": true}`)
|
||||
- Node 3: `AgentExecutorBlock` (sub-agent A — set `graph_id`, `graph_version`,
|
||||
`input_schema`, `output_schema` from library agent)
|
||||
- Node 4: `AgentExecutorBlock` (sub-agent B — same pattern)
|
||||
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
|
||||
- Links:
|
||||
- Input→SDM: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- SDM→Agent A (per input field): `source_name: "tools"`,
|
||||
`sink_name: "<agent_a_input_field>"`
|
||||
- SDM→Agent B (per input field): `source_name: "tools"`,
|
||||
`sink_name: "<agent_b_input_field>"`
|
||||
- SDM→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
|
||||
**Example — Orchestrator calling regular blocks as tools:**
|
||||
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
|
||||
- Node 2: `SmartDecisionMakerBlock` (input_default:
|
||||
`{"agent_mode_max_iterations": 5, "conversation_compaction": true}`)
|
||||
- Node 3: `GetWebpageBlock` (regular block — the AI calls it as a tool)
|
||||
- Node 4: `AITextGeneratorBlock` (another regular block as a tool)
|
||||
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
|
||||
- Links:
|
||||
- Input→SDM: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- SDM→GetWebpage: `source_name: "tools"`, `sink_name: "url"`
|
||||
- SDM→AITextGenerator: `source_name: "tools"`, `sink_name: "prompt"`
|
||||
- SDM→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
|
||||
Regular blocks work exactly like sub-agents as tools — wire each input
|
||||
field from `source_name: "tools"` on the SmartDecisionMaker side.
|
||||
|
||||
### Example: Simple AI Text Processor
|
||||
|
||||
A minimal agent with input, processing, and output:
|
||||
|
||||
108
autogpt_platform/backend/backend/copilot/sdk/collect.py
Normal file
108
autogpt_platform/backend/backend/copilot/sdk/collect.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Public helpers for consuming a copilot stream as a simple request-response.
|
||||
|
||||
This module exposes :class:`CopilotResult` and :func:`collect_copilot_response`
|
||||
so that callers (e.g. the AutoPilot block) can consume the copilot stream
|
||||
without implementing their own event loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class CopilotResult:
|
||||
"""Aggregated result from consuming a copilot stream.
|
||||
|
||||
Returned by :func:`collect_copilot_response` so callers don't need to
|
||||
implement their own event-loop over the raw stream events.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"response_text",
|
||||
"tool_calls",
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.response_text: str = ""
|
||||
self.tool_calls: list[dict[str, Any]] = []
|
||||
self.prompt_tokens: int = 0
|
||||
self.completion_tokens: int = 0
|
||||
self.total_tokens: int = 0
|
||||
|
||||
|
||||
async def collect_copilot_response(
|
||||
*,
|
||||
session_id: str,
|
||||
message: str,
|
||||
user_id: str,
|
||||
is_user_message: bool = True,
|
||||
) -> CopilotResult:
|
||||
"""Consume :func:`stream_chat_completion_sdk` and return aggregated results.
|
||||
|
||||
This is the recommended entry-point for callers that need a simple
|
||||
request-response interface (e.g. the AutoPilot block) rather than
|
||||
streaming individual events. It avoids duplicating the event-collection
|
||||
logic and does NOT wrap the stream in ``asyncio.timeout`` — the SDK
|
||||
manages its own heartbeat-based timeouts internally.
|
||||
|
||||
Args:
|
||||
session_id: Chat session to use.
|
||||
message: The user message / prompt.
|
||||
user_id: Authenticated user ID.
|
||||
is_user_message: Whether this is a user-initiated message.
|
||||
|
||||
Returns:
|
||||
A :class:`CopilotResult` with the aggregated response text,
|
||||
tool calls, and token usage.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the stream yields a ``StreamError`` event.
|
||||
"""
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
from .service import stream_chat_completion_sdk
|
||||
|
||||
result = CopilotResult()
|
||||
response_parts: list[str] = []
|
||||
tool_calls_by_id: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async for event in stream_chat_completion_sdk(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
):
|
||||
if isinstance(event, StreamTextDelta):
|
||||
response_parts.append(event.delta)
|
||||
elif isinstance(event, StreamToolInputAvailable):
|
||||
entry: dict[str, Any] = {
|
||||
"tool_call_id": event.toolCallId,
|
||||
"tool_name": event.toolName,
|
||||
"input": event.input,
|
||||
"output": None,
|
||||
"success": None,
|
||||
}
|
||||
result.tool_calls.append(entry)
|
||||
tool_calls_by_id[event.toolCallId] = entry
|
||||
elif isinstance(event, StreamToolOutputAvailable):
|
||||
if tc := tool_calls_by_id.get(event.toolCallId):
|
||||
tc["output"] = event.output
|
||||
tc["success"] = event.success
|
||||
elif isinstance(event, StreamUsage):
|
||||
result.prompt_tokens += event.prompt_tokens
|
||||
result.completion_tokens += event.completion_tokens
|
||||
result.total_tokens += event.total_tokens
|
||||
elif isinstance(event, StreamError):
|
||||
raise RuntimeError(f"Copilot error: {event.errorText}")
|
||||
|
||||
result.response_text = "".join(response_parts)
|
||||
return result
|
||||
@@ -12,6 +12,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
@@ -119,14 +120,12 @@ def filter_compaction_messages(
|
||||
filtered: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
real_calls: list[dict[str, Any]] = []
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
|
||||
compaction_ids.add(tc.get("id", ""))
|
||||
real_calls = [
|
||||
tc
|
||||
for tc in msg.tool_calls
|
||||
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
|
||||
]
|
||||
else:
|
||||
real_calls.append(tc)
|
||||
if not real_calls and not msg.content:
|
||||
continue
|
||||
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
|
||||
@@ -222,6 +221,7 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
54
autogpt_platform/backend/backend/copilot/sdk/conftest.py
Normal file
54
autogpt_platform/backend/backend/copilot/sdk/conftest.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Shared test fixtures for copilot SDK tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_chat_config():
|
||||
"""Mock ChatConfig so compact_transcript tests skip real config lookup."""
|
||||
with patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type("Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"})(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
|
||||
"""Build a minimal valid JSONL transcript from (role, content) pairs.
|
||||
|
||||
Use this helper in any copilot SDK test that needs a well-formed
|
||||
transcript without hitting the real storage layer.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for role, content in pairs:
|
||||
uid = str(uuid4())
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if role == "assistant":
|
||||
msg.update(
|
||||
{
|
||||
"model": "",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": msg,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
@@ -1,9 +1,17 @@
|
||||
"""Dummy SDK service for testing copilot streaming.
|
||||
|
||||
Returns mock streaming responses without calling Claude Agent SDK.
|
||||
Enable via COPILOT_TEST_MODE=true environment variable.
|
||||
Enable via CHAT_TEST_MODE=true in .env (ChatConfig.test_mode).
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
|
||||
Magic keywords (case-insensitive, anywhere in message):
|
||||
__test_transient_error__ — Simulate a transient Anthropic API error
|
||||
(ECONNRESET). Streams partial text, then
|
||||
yields StreamError with retryable prefix.
|
||||
__test_fatal_error__ — Simulate a non-retryable SDK error.
|
||||
__test_slow_response__ — Simulate a slow response (2s per word).
|
||||
(no keyword) — Normal dummy response.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -12,12 +20,39 @@ import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
FRIENDLY_TRANSIENT_MSG,
|
||||
)
|
||||
from ..model import ChatMessage, ChatSession, get_chat_session, upsert_chat_session
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _safe_upsert(session: ChatSession) -> None:
|
||||
"""Best-effort session persist — skip silently if DB is unavailable."""
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception:
|
||||
logger.debug("[TEST MODE] Could not persist session (DB unavailable)")
|
||||
|
||||
|
||||
def _has_keyword(message: str | None, keyword: str) -> bool:
|
||||
return keyword in (message or "").lower()
|
||||
|
||||
|
||||
async def stream_chat_completion_dummy(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -36,24 +71,89 @@ async def stream_chat_completion_dummy(
|
||||
- No timeout occurs
|
||||
- Text arrives in chunks
|
||||
- StreamFinish is sent by mark_session_completed
|
||||
|
||||
See module docstring for magic keywords that trigger error scenarios.
|
||||
"""
|
||||
logger.warning(
|
||||
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
|
||||
)
|
||||
|
||||
# Load session from DB (matches SDK service behaviour) so error markers
|
||||
# and the assistant reply are persisted and survive page refresh.
|
||||
# Best-effort: skip if DB is unavailable (e.g. unit tests).
|
||||
if session is None:
|
||||
try:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
except Exception:
|
||||
logger.debug("[TEST MODE] Could not load session (DB unavailable)")
|
||||
session = None
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Start the stream
|
||||
# Start the stream (matches baseline: StreamStart → StreamStartStep)
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
yield StreamStartStep()
|
||||
|
||||
# Simulate streaming text response with delays
|
||||
# --- Magic keyword: transient error (retryable) -------------------------
|
||||
if _has_keyword(message, "__test_transient_error__"):
|
||||
# Stream some partial text first (simulates mid-stream failure)
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
for word in ["Working", "on", "it..."]:
|
||||
yield StreamTextDelta(id=text_block_id, delta=f"{word} ")
|
||||
await asyncio.sleep(0.1)
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
yield StreamFinishStep()
|
||||
# Persist retryable marker so "Try Again" button shows after refresh
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} {FRIENDLY_TRANSIENT_MSG}",
|
||||
)
|
||||
)
|
||||
await _safe_upsert(session)
|
||||
yield StreamError(
|
||||
errorText=FRIENDLY_TRANSIENT_MSG,
|
||||
code="transient_api_error",
|
||||
)
|
||||
return
|
||||
|
||||
# --- Magic keyword: fatal error (non-retryable) -------------------------
|
||||
if _has_keyword(message, "__test_fatal_error__"):
|
||||
yield StreamFinishStep()
|
||||
error_msg = "Internal SDK error: model refused to respond"
|
||||
# Persist non-retryable error marker
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_ERROR_PREFIX} {error_msg}",
|
||||
)
|
||||
)
|
||||
await _safe_upsert(session)
|
||||
yield StreamError(errorText=error_msg, code="sdk_error")
|
||||
return
|
||||
|
||||
# --- Magic keyword: slow response ---------------------------------------
|
||||
delay = 2.0 if _has_keyword(message, "__test_slow_response__") else 0.1
|
||||
|
||||
# --- Normal dummy response ----------------------------------------------
|
||||
dummy_response = "I counted: 1... 2... 3. All done!"
|
||||
words = dummy_response.split()
|
||||
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
for i, word in enumerate(words):
|
||||
# Add space except for last word
|
||||
text = word if i == len(words) - 1 else f"{word} "
|
||||
yield StreamTextDelta(id=text_block_id, delta=text)
|
||||
# Small delay to simulate real streaming
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(delay)
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
|
||||
# Persist the assistant reply so it survives page refresh
|
||||
if session:
|
||||
session.messages.append(ChatMessage(role="assistant", content=dummy_response))
|
||||
await _safe_upsert(session)
|
||||
|
||||
yield StreamFinishStep()
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -26,6 +26,41 @@ from backend.copilot.context import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _check_sandbox_symlink_escape(
|
||||
sandbox: Any,
|
||||
parent: str,
|
||||
) -> str | None:
|
||||
"""Resolve the canonical parent path inside the sandbox to detect symlink escapes.
|
||||
|
||||
``normpath`` (used by ``resolve_sandbox_path``) only normalises the string;
|
||||
``readlink -f`` follows actual symlinks on the sandbox filesystem.
|
||||
|
||||
Returns the canonical parent path, or ``None`` if the path escapes
|
||||
``E2B_WORKDIR``.
|
||||
|
||||
Note: There is an inherent TOCTOU window between this check and the
|
||||
subsequent ``sandbox.files.write()``. A symlink could theoretically be
|
||||
replaced between the two operations. This is acceptable in the E2B
|
||||
sandbox model since the sandbox is single-user and ephemeral.
|
||||
"""
|
||||
canonical_res = await sandbox.commands.run(
|
||||
f"readlink -f {shlex.quote(parent or E2B_WORKDIR)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=5,
|
||||
)
|
||||
canonical_parent = (canonical_res.stdout or "").strip()
|
||||
if (
|
||||
canonical_res.exit_code != 0
|
||||
or not canonical_parent
|
||||
or (
|
||||
canonical_parent != E2B_WORKDIR
|
||||
and not canonical_parent.startswith(E2B_WORKDIR + "/")
|
||||
)
|
||||
):
|
||||
return None
|
||||
return canonical_parent
|
||||
|
||||
|
||||
def _get_sandbox():
|
||||
return get_current_sandbox()
|
||||
|
||||
@@ -106,6 +141,10 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
await sandbox.files.make_dir(parent)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
@@ -130,6 +169,12 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
parent = os.path.dirname(remote)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
|
||||
@@ -4,15 +4,19 @@ Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
|
||||
|
||||
from .e2b_file_tools import (
|
||||
_check_sandbox_symlink_escape,
|
||||
_read_local,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
@@ -21,46 +25,48 @@ _SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
class TestResolveSandboxPath:
|
||||
def test_relative_path_resolved(self):
|
||||
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||
assert resolve_sandbox_path("src/main.py") == f"{E2B_WORKDIR}/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||
assert (
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/file.txt") == f"{E2B_WORKDIR}/file.txt"
|
||||
)
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
assert resolve_sandbox_path(E2B_WORKDIR) == E2B_WORKDIR
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == f"{E2B_WORKDIR}/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||
assert resolve_sandbox_path("src/") == f"{E2B_WORKDIR}/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
"""Path that resolves back within E2B_WORKDIR is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -73,9 +79,13 @@ class TestResolveSandboxPath:
|
||||
|
||||
|
||||
class TestReadLocal:
|
||||
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
|
||||
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
"""Create a tool-results file under <encoded>/<uuid>/tool-results/."""
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
@@ -107,7 +117,9 @@ class TestReadLocal:
|
||||
def test_read_nonexistent_tool_results(self):
|
||||
"""A tool-results path that doesn't exist returns FileNotFoundError."""
|
||||
encoded = "-tmp-copilot-e2b-test-nofile"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
@@ -117,7 +129,7 @@ class TestReadLocal:
|
||||
assert "not found" in result["content"][0]["text"].lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.rmdir(tool_results_dir)
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
|
||||
def test_read_traversal_path_blocked(self):
|
||||
"""A traversal attempt that escapes allowed directories is blocked."""
|
||||
@@ -152,3 +164,66 @@ class TestReadLocal:
|
||||
"""Without _current_project_dir set, all paths are blocked."""
|
||||
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_sandbox_symlink_escape — symlink escape detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_sandbox(stdout: str, exit_code: int = 0) -> SimpleNamespace:
|
||||
"""Build a minimal sandbox mock whose commands.run returns a fixed result."""
|
||||
run_result = SimpleNamespace(stdout=stdout, exit_code=exit_code)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
return SimpleNamespace(commands=commands)
|
||||
|
||||
|
||||
class TestCheckSandboxSymlinkEscape:
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonical_path_within_workdir_returns_path(self):
|
||||
"""When readlink -f resolves to a path inside E2B_WORKDIR, returns it."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/src\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result == f"{E2B_WORKDIR}/src"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workdir_itself_returns_workdir(self):
|
||||
"""When readlink -f resolves to E2B_WORKDIR exactly, returns E2B_WORKDIR."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, E2B_WORKDIR)
|
||||
assert result == E2B_WORKDIR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_symlink_escape_returns_none(self):
|
||||
"""When readlink -f resolves outside E2B_WORKDIR (symlink escape), returns None."""
|
||||
sandbox = _make_sandbox(stdout="/etc\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonzero_exit_code_returns_none(self):
|
||||
"""A non-zero exit code from readlink -f returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=1)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_stdout_returns_none(self):
|
||||
"""Empty stdout from readlink (e.g. path doesn't exist yet) returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefix_collision_returns_none(self):
|
||||
"""A path prefixed with E2B_WORKDIR but not within it is rejected."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}-evil\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}-evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deeply_nested_path_within_workdir(self):
|
||||
"""Deep nested paths inside E2B_WORKDIR are allowed."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
|
||||
assert result == f"{E2B_WORKDIR}/a/b/c/d"
|
||||
|
||||
@@ -0,0 +1,651 @@
|
||||
"""Tests for retry logic and transcript compaction helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _friendly_error_text, _is_prompt_too_long
|
||||
from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_run_compression,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
|
||||
def test_text_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"},
|
||||
]
|
||||
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||
|
||||
def test_tool_use_blocks(self):
|
||||
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||
|
||||
def test_mixed_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Let me read that."},
|
||||
{"type": "tool_use", "name": "Read", "input": {"path": "/foo"}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "Let me read that." in result
|
||||
assert "[tool_use: Read]" in result
|
||||
|
||||
def test_raw_strings(self):
|
||||
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "See this image:"},
|
||||
{"type": "image", "source": {"type": "base64", "data": "..."}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "See this image:" in result
|
||||
assert "[__image__]" in result
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_assistant_content([]) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_tool_result_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
|
||||
def test_tool_result_with_text(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": [{"type": "text", "text": "file contents here"}],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||
|
||||
def test_tool_result_with_string_content(self):
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "123", "content": "ok"}]
|
||||
assert _flatten_tool_result_content(blocks) == "ok"
|
||||
|
||||
def test_text_block(self):
|
||||
blocks = [{"type": "text", "text": "plain text"}]
|
||||
assert _flatten_tool_result_content(blocks) == "plain text"
|
||||
|
||||
def test_raw_string(self):
|
||||
assert _flatten_tool_result_content(["raw"]) == "raw"
|
||||
|
||||
def test_tool_result_with_none_content(self):
|
||||
"""tool_result with content=None should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": None}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_tool_result_with_empty_list_content(self):
|
||||
"""tool_result with content=[] should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": []}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_tool_result_content([]) == ""
|
||||
|
||||
def test_nested_dict_without_text(self):
|
||||
"""Dict blocks without text key use json.dumps fallback."""
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "x",
|
||||
"content": [{"type": "image", "source": "data:..."}],
|
||||
}
|
||||
]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "image" in result # json.dumps fallback
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [{"type": "image", "source": {"type": "base64", "data": "..."}}]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "[__image__]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_entry(entry_type: str, role: str, content: str | list, **kwargs) -> str:
|
||||
"""Build a JSONL line for testing."""
|
||||
uid = str(uuid4())
|
||||
msg: dict = {"role": role, "content": content}
|
||||
msg.update(kwargs)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": None,
|
||||
"message": msg,
|
||||
}
|
||||
return json.dumps(entry, separators=(",", ":"))
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
|
||||
def test_basic_roundtrip(self):
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == {"role": "user", "content": "Hello"}
|
||||
assert messages[1] == {"role": "assistant", "content": "Hi"}
|
||||
|
||||
def test_skips_strippable_types(self):
|
||||
"""Progress and metadata entries are excluded."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "progress",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"role": "assistant", "content": "..."},
|
||||
}
|
||||
),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _transcript_to_messages("") == []
|
||||
|
||||
def test_tool_result_content(self):
|
||||
"""User entries with tool_result content blocks are flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": "tool output",
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "tool output"
|
||||
|
||||
def test_malformed_json_lines_skipped(self):
|
||||
"""Malformed JSON lines in transcript are silently skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"this is not valid json",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_lines_skipped(self):
|
||||
"""Empty lines and whitespace-only lines are skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"",
|
||||
" ",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_unicode_content_preserved(self):
|
||||
"""Unicode characters survive transcript roundtrip."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello 你好 🌍"),
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[{"type": "text", "text": "Bonjour 日本語 émojis 🎉"}],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert messages[0]["content"] == "Hello 你好 🌍"
|
||||
assert messages[1]["content"] == "Bonjour 日本語 émojis 🎉"
|
||||
|
||||
def test_entry_without_role_skipped(self):
|
||||
"""Entries with missing role in message are skipped."""
|
||||
entry_no_role = json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"content": "no role here"},
|
||||
}
|
||||
)
|
||||
lines = [
|
||||
entry_no_role,
|
||||
_make_entry("user", "user", "Hello"),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "Hello"
|
||||
|
||||
def test_tool_use_and_result_pairs(self):
|
||||
"""Tool use + tool result pairs are properly flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[
|
||||
{"type": "text", "text": "Let me check."},
|
||||
{"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
|
||||
],
|
||||
),
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "abc",
|
||||
"content": [{"type": "text", "text": "file contents"}],
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert "Let me check." in messages[0]["content"]
|
||||
assert "[tool_use: read_file]" in messages[0]["content"]
|
||||
assert messages[1]["content"] == "file contents"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _messages_to_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
|
||||
def test_produces_valid_jsonl(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
assert len(lines) == 2
|
||||
for line in lines:
|
||||
parsed = json.loads(line)
|
||||
assert "type" in parsed
|
||||
assert "uuid" in parsed
|
||||
assert "message" in parsed
|
||||
|
||||
def test_assistant_has_proper_structure(self):
|
||||
messages = [{"role": "assistant", "content": "Hello"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "assistant"
|
||||
msg = entry["message"]
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["type"] == "message"
|
||||
assert msg["stop_reason"] == "end_turn"
|
||||
assert isinstance(msg["content"], list)
|
||||
assert msg["content"][0]["type"] == "text"
|
||||
|
||||
def test_user_has_plain_content(self):
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "user"
|
||||
assert entry["message"]["content"] == "Hi"
|
||||
|
||||
def test_parent_uuid_chain(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "A"},
|
||||
{"role": "assistant", "content": "B"},
|
||||
{"role": "user", "content": "C"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
entries = [json.loads(line) for line in lines]
|
||||
assert entries[0]["parentUuid"] == ""
|
||||
assert entries[1]["parentUuid"] == entries[0]["uuid"]
|
||||
assert entries[2]["parentUuid"] == entries[1]["uuid"]
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _messages_to_transcript([]) == ""
|
||||
|
||||
def test_output_is_valid_transcript(self):
|
||||
"""Output should pass validate_transcript if it has assistant entries."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result)
|
||||
|
||||
def test_roundtrip_to_messages(self):
|
||||
"""Messages → transcript → messages preserves structure."""
|
||||
original = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
transcript = _messages_to_transcript(original)
|
||||
restored = _transcript_to_messages(transcript)
|
||||
assert len(restored) == len(original)
|
||||
for orig, rest in zip(original, restored):
|
||||
assert orig["role"] == rest["role"]
|
||||
assert orig["content"] == rest["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compact_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_few_messages_returns_none(self, mock_chat_config):
|
||||
"""compact_transcript returns None when transcript has < 2 messages."""
|
||||
transcript = _build_transcript([("user", "Hello")])
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_compacted(self, mock_chat_config):
|
||||
"""When compress_context says no compaction needed, returns None.
|
||||
The compressor couldn't reduce it, so retrying with the same
|
||||
content would fail identically."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi there"),
|
||||
]
|
||||
)
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": False,
|
||||
"messages": [],
|
||||
"original_token_count": 100,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_compacted_transcript(self, mock_chat_config):
|
||||
"""When compaction succeeds, returns a valid compacted transcript."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
("user", "More"),
|
||||
("assistant", "Details"),
|
||||
]
|
||||
)
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 500,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[1]["content"] == "Summarized response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_compression_failure(self, mock_chat_config):
|
||||
"""When _run_compression raises, returns None."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
]
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
"""Unit tests for _is_prompt_too_long pattern matching."""
|
||||
|
||||
def test_prompt_is_too_long(self):
|
||||
err = RuntimeError("prompt is too long for model context")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_request_too_large(self):
|
||||
err = Exception("request too large: 250000 tokens")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_maximum_context_length(self):
|
||||
err = ValueError("maximum context length exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_context_length_exceeded(self):
|
||||
err = Exception("context_length_exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_tokens_exceed(self):
|
||||
err = Exception("input tokens exceed the max_tokens limit")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_is_too_long(self):
|
||||
err = Exception("input is too long for the model")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_content_length_exceeds(self):
|
||||
err = Exception("content length exceeds maximum")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_unrelated_error_returns_false(self):
|
||||
err = RuntimeError("network timeout")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_auth_error_returns_false(self):
|
||||
err = Exception("authentication failed: invalid API key")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_chained_exception_detected(self):
|
||||
"""Prompt-too-long error wrapped in another exception is detected."""
|
||||
inner = RuntimeError("prompt is too long")
|
||||
outer = Exception("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
err = Exception("PROMPT IS TOO LONG")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_old_max_tokens_exceeded_not_matched(self):
|
||||
"""The old broad 'max_tokens_exceeded' pattern was removed.
|
||||
Only 'input tokens exceed' should match now."""
|
||||
err = Exception("max_tokens_exceeded")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_compression timeout fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunCompressionTimeout:
|
||||
"""Verify _run_compression falls back to truncation when LLM times out."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_falls_back_to_truncation(self):
|
||||
"""When compress_context with LLM client times out,
|
||||
_run_compression falls back to truncation (client=None)."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
truncation_result = CompressResult(
|
||||
messages=messages,
|
||||
was_compacted=False,
|
||||
original_token_count=50,
|
||||
token_count=50,
|
||||
messages_summarized=0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
call_args: list[dict] = []
|
||||
|
||||
async def _mock_compress(**kwargs):
|
||||
call_args.append(kwargs)
|
||||
if kwargs.get("client") is not None:
|
||||
# Simulate timeout by raising asyncio.TimeoutError
|
||||
raise asyncio.TimeoutError("LLM compaction timed out")
|
||||
return truncation_result
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value="fake-client",
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_mock_compress,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(messages, "test-model", "[test]")
|
||||
|
||||
assert result == truncation_result
|
||||
# Should have been called twice: once with client, once without
|
||||
assert len(call_args) == 2
|
||||
assert call_args[0]["client"] is not None # LLM attempt
|
||||
assert call_args[1]["client"] is None # truncation fallback
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation_directly(self):
|
||||
"""When no OpenAI client is configured, goes straight to truncation."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
truncation_result = CompressResult(
|
||||
messages=messages,
|
||||
was_compacted=False,
|
||||
original_token_count=50,
|
||||
token_count=50,
|
||||
messages_summarized=0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(messages, "test-model", "[test]")
|
||||
|
||||
assert result == truncation_result
|
||||
mock_compress.assert_called_once()
|
||||
# When no client, compress_context is called with client=None
|
||||
assert mock_compress.call_args.kwargs.get("client") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _friendly_error_text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFriendlyErrorText:
|
||||
"""Verify user-friendly error message mapping."""
|
||||
|
||||
def test_authentication_error(self):
|
||||
result = _friendly_error_text("authentication failed: invalid API key")
|
||||
assert "Authentication" in result
|
||||
assert "API key" in result
|
||||
|
||||
def test_rate_limit_error(self):
|
||||
result = _friendly_error_text("rate limit exceeded")
|
||||
assert "Rate limit" in result
|
||||
|
||||
def test_overloaded_error(self):
|
||||
result = _friendly_error_text("API is overloaded")
|
||||
assert "overloaded" in result
|
||||
|
||||
def test_timeout_error(self):
|
||||
result = _friendly_error_text("Request timeout after 30s")
|
||||
assert "timed out" in result
|
||||
|
||||
def test_connection_error(self):
|
||||
result = _friendly_error_text("Connection refused")
|
||||
assert "Connection" in result or "connection" in result
|
||||
|
||||
def test_unknown_error_passthrough(self):
|
||||
result = _friendly_error_text("some unknown error XYZ")
|
||||
assert "SDK stream error:" in result
|
||||
assert "XYZ" in result
|
||||
|
||||
def test_unauthorized_error(self):
|
||||
result = _friendly_error_text("401 Unauthorized")
|
||||
assert "Authentication" in result
|
||||
@@ -20,6 +20,7 @@ from claude_agent_sdk import (
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.copilot.constants import FRIENDLY_TRANSIENT_MSG, is_transient_api_error
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -214,10 +215,12 @@ class SDKResponseAdapter:
|
||||
if sdk_message.subtype == "success":
|
||||
responses.append(StreamFinish())
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
error_msg = sdk_message.result or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
raw_error = str(sdk_message.result or "Unknown error")
|
||||
if is_transient_api_error(raw_error):
|
||||
error_text, code = FRIENDLY_TRANSIENT_MSG, "transient_api_error"
|
||||
else:
|
||||
error_text, code = raw_error, "sdk_error"
|
||||
responses.append(StreamError(errorText=error_text, code=code))
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
1410
autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
Normal file
1410
autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -42,7 +42,7 @@ def _validate_workspace_path(
|
||||
Delegates to :func:`is_allowed_local_path` which permits:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The current session's tool-results directory
|
||||
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
|
||||
(``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
@@ -302,7 +302,11 @@ def create_security_hooks(
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
# Sanitize untrusted input before logging to prevent log injection
|
||||
# Sanitize untrusted input: strip control chars for logging AND
|
||||
# for the value passed downstream. read_compacted_entries()
|
||||
# validates against _projects_base() as defence-in-depth, but
|
||||
# sanitizing here prevents log injection and rejects obviously
|
||||
# malformed paths early.
|
||||
transcript_path = (
|
||||
str(input_data.get("transcript_path", ""))
|
||||
.replace("\n", "")
|
||||
|
||||
@@ -122,7 +122,7 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-results/12345.txt"
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,283 @@
|
||||
"""Unit tests for extracted service helpers.
|
||||
|
||||
Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
|
||||
and the ``ReducedContext`` named tuple.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import (
|
||||
ReducedContext,
|
||||
_is_prompt_too_long,
|
||||
_iter_sdk_messages,
|
||||
_reduce_context,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
def test_direct_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("prompt is too long")) is True
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("PROMPT IS TOO LONG")) is True
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("network timeout")) is False
|
||||
|
||||
def test_request_too_large(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("request too large for model")) is True
|
||||
|
||||
def test_context_length_exceeded(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("context_length_exceeded")) is True
|
||||
|
||||
def test_max_tokens_exceeded_not_matched(self) -> None:
|
||||
"""'max_tokens_exceeded' is intentionally excluded (too broad)."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens_exceeded")) is False
|
||||
|
||||
def test_max_tokens_config_error_no_match(self) -> None:
|
||||
"""'max_tokens must be at least 1' should NOT match."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens must be at least 1")) is False
|
||||
|
||||
def test_chained_cause(self) -> None:
|
||||
inner = Exception("prompt is too long")
|
||||
outer = RuntimeError("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_chained_context(self) -> None:
|
||||
inner = Exception("request too large")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__context__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_deep_chain(self) -> None:
|
||||
bottom = Exception("maximum context length")
|
||||
middle = RuntimeError("middle")
|
||||
middle.__cause__ = bottom
|
||||
top = ValueError("top")
|
||||
top.__cause__ = middle
|
||||
assert _is_prompt_too_long(top) is True
|
||||
|
||||
def test_chain_no_match(self) -> None:
|
||||
inner = Exception("rate limit exceeded")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is False
|
||||
|
||||
def test_cycle_detection(self) -> None:
|
||||
"""Exception chain with a cycle should not infinite-loop."""
|
||||
a = Exception("error a")
|
||||
b = Exception("error b")
|
||||
a.__cause__ = b
|
||||
b.__cause__ = a # cycle
|
||||
assert _is_prompt_too_long(a) is False
|
||||
|
||||
def test_all_patterns(self) -> None:
|
||||
patterns = [
|
||||
"prompt is too long",
|
||||
"request too large",
|
||||
"maximum context length",
|
||||
"context_length_exceeded",
|
||||
"input tokens exceed",
|
||||
"input is too long",
|
||||
"content length exceeds",
|
||||
]
|
||||
for pattern in patterns:
|
||||
assert _is_prompt_too_long(Exception(pattern)) is True, pattern
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reduce_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReduceContext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_retry_compaction_success(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value="/tmp/resume.jsonl",
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert isinstance(ctx, ReducedContext)
|
||||
assert ctx.use_resume is True
|
||||
assert ctx.resume_file == "/tmp/resume.jsonl"
|
||||
assert ctx.transcript_lost is False
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_fails_drops_transcript(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.resume_file is None
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_tried_compaction_skips(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
ctx = await _reduce_context(transcript, True, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_drops(self) -> None:
|
||||
ctx = await _reduce_context("", False, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_returns_same_content_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript, # same content
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_tempfile_fails_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _iter_sdk_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIterSdkMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_yields_messages(self) -> None:
|
||||
messages = ["msg1", "msg2", "msg3"]
|
||||
client = AsyncMock()
|
||||
|
||||
async def _fake_receive() -> AsyncGenerator[str]:
|
||||
for m in messages:
|
||||
yield m
|
||||
|
||||
client.receive_response = _fake_receive
|
||||
result = [msg async for msg in _iter_sdk_messages(client)]
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_on_timeout(self) -> None:
|
||||
"""Yields None when asyncio.wait times out."""
|
||||
client = AsyncMock()
|
||||
received: list = []
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
await asyncio.sleep(100) # never completes
|
||||
yield "never" # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
with patch("backend.copilot.sdk.service._HEARTBEAT_INTERVAL", 0.01):
|
||||
count = 0
|
||||
async for msg in _iter_sdk_messages(client):
|
||||
received.append(msg)
|
||||
count += 1
|
||||
if count >= 3:
|
||||
break
|
||||
|
||||
assert all(m is None for m in received)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_propagates(self) -> None:
|
||||
client = AsyncMock()
|
||||
|
||||
async def _error_receive() -> AsyncGenerator[str]:
|
||||
raise RuntimeError("SDK crash")
|
||||
yield # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _error_receive
|
||||
|
||||
with pytest.raises(RuntimeError, match="SDK crash"):
|
||||
async for _ in _iter_sdk_messages(client):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cleanup_on_break(self) -> None:
|
||||
"""Pending task is cancelled when generator is closed."""
|
||||
client = AsyncMock()
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
yield "first"
|
||||
await asyncio.sleep(100)
|
||||
yield "second"
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
gen = _iter_sdk_messages(client)
|
||||
first = await gen.__anext__()
|
||||
assert first == "first"
|
||||
await gen.aclose() # should cancel pending task cleanly
|
||||
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .service import _prepare_file_attachments
|
||||
from .service import _prepare_file_attachments, _resolve_sdk_model
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -288,3 +288,214 @@ class TestPromptSupplement:
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cleanup_sdk_tool_results — orchestration + rate-limiting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupSdkToolResults:
|
||||
"""Tests for _cleanup_sdk_tool_results orchestration and sweep rate-limiting."""
|
||||
|
||||
# All valid cwds must start with /tmp/copilot- (the _SDK_CWD_PREFIX).
|
||||
_CWD_PREFIX = "/tmp/copilot-"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removes_cwd_directory(self):
|
||||
"""Cleanup removes the session working directory."""
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-cleanup-remove"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch("backend.copilot.sdk.service.cleanup_stale_project_dirs"):
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
svc_mod._last_sweep_time = 0.0
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
assert not os.path.exists(cwd)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_runs_when_interval_elapsed(self):
|
||||
"""cleanup_stale_project_dirs is called when 5-minute interval has elapsed."""
|
||||
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-sweep-elapsed"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
# Set last sweep to a time far in the past
|
||||
svc_mod._last_sweep_time = 0.0
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
mock_sweep.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_skipped_within_interval(self):
|
||||
"""cleanup_stale_project_dirs is NOT called when within 5-minute interval."""
|
||||
import time
|
||||
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-sweep-ratelimit"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
# Set last sweep to now — interval not elapsed
|
||||
svc_mod._last_sweep_time = time.time()
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
mock_sweep.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_path_outside_prefix(self, tmp_path):
|
||||
"""Cleanup rejects a cwd that does not start with the expected prefix."""
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
evil_cwd = str(tmp_path / "evil-path")
|
||||
os.makedirs(evil_cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
await _cleanup_sdk_tool_results(evil_cwd)
|
||||
|
||||
# Directory should NOT have been removed (rejected early)
|
||||
assert os.path.exists(evil_cwd)
|
||||
mock_sweep.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Env vars that ChatConfig validators read — must be cleared so explicit
|
||||
# constructor values are used.
|
||||
# ---------------------------------------------------------------------------
|
||||
_CONFIG_ENV_VARS = (
|
||||
"CHAT_USE_OPENROUTER",
|
||||
"CHAT_API_KEY",
|
||||
"OPEN_ROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"CHAT_BASE_URL",
|
||||
"OPENROUTER_BASE_URL",
|
||||
"OPENAI_BASE_URL",
|
||||
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
|
||||
"CHAT_USE_CLAUDE_AGENT_SDK",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _CONFIG_ENV_VARS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestResolveSdkModel:
|
||||
"""Tests for _resolve_sdk_model — model ID resolution for the SDK CLI."""
|
||||
|
||||
def test_openrouter_active_keeps_dots(self, monkeypatch, _clean_config_env):
|
||||
"""When OpenRouter is fully active, model keeps dot-separated version."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4.6"
|
||||
|
||||
def test_openrouter_disabled_normalizes_to_hyphens(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""When OpenRouter is disabled, dots are replaced with hyphens."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
def test_openrouter_enabled_but_missing_key_normalizes(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""When OpenRouter is enabled but api_key is missing, falls back to
|
||||
direct Anthropic and normalizes dots to hyphens."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key=None,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
def test_explicit_claude_agent_model_takes_precedence(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""When claude_agent_model is explicitly set, it is returned as-is."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model="claude-sonnet-4-5-20250514",
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-sonnet-4-5-20250514"
|
||||
|
||||
def test_subscription_mode_returns_none(self, monkeypatch, _clean_config_env):
|
||||
"""When using Claude Code subscription, returns None (CLI picks model)."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=True,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() is None
|
||||
|
||||
def test_model_without_provider_prefix(self, monkeypatch, _clean_config_env):
|
||||
"""When model has no provider prefix, it still normalizes correctly."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
@@ -146,7 +146,7 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
|
||||
event.set()
|
||||
|
||||
|
||||
async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
async def wait_for_stash(timeout: float = 2.0) -> bool:
|
||||
"""Wait for a PostToolUse hook to stash tool output.
|
||||
|
||||
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
|
||||
@@ -155,12 +155,12 @@ async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
by waiting on the ``_stash_event``, which is signaled by
|
||||
:func:`stash_pending_tool_output`.
|
||||
|
||||
After the event fires, callers should ``await asyncio.sleep(0)`` to
|
||||
give any remaining concurrent hooks a chance to complete.
|
||||
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
The timeout is a safety net — normally the stash happens within
|
||||
microseconds of yielding to the event loop.
|
||||
|
||||
The 2.0 s default was chosen based on production metrics: the original
|
||||
0.5 s caused frequent timeouts under load (parallel tool calls, large
|
||||
outputs). 2.0 s gives a comfortable margin while still failing fast
|
||||
when the hook genuinely will not fire.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is None:
|
||||
@@ -285,7 +285,7 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(resolved) as f:
|
||||
with open(resolved, encoding="utf-8", errors="replace") as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
|
||||
@@ -10,6 +10,9 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -17,8 +20,12 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -99,7 +106,14 @@ def strip_progress_entries(content: str) -> str:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
while parent in stripped_uuids:
|
||||
# seen_parents is local per-entry (not shared across iterations) so
|
||||
# it can only detect cycles within a single ancestry walk, not across
|
||||
# entries. This is intentional: each entry's parent chain is
|
||||
# independent, and reusing a global set would incorrectly short-circuit
|
||||
# valid re-use of the same UUID as a parent in different subtrees.
|
||||
seen_parents: set[str] = set()
|
||||
while parent in stripped_uuids and parent not in seen_parents:
|
||||
seen_parents.add(parent)
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
@@ -151,44 +165,110 @@ def _projects_base() -> str:
|
||||
return os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
|
||||
|
||||
def _cli_project_dir(sdk_cwd: str) -> str | None:
|
||||
"""Return the CLI's project directory for a given working directory.
|
||||
_STALE_PROJECT_DIR_SECONDS = 12 * 3600 # 12 hours — matches max session lifetime
|
||||
_MAX_PROJECT_DIRS_TO_SWEEP = 50 # limit per sweep to avoid long pauses
|
||||
|
||||
Returns ``None`` if the path would escape the projects base.
|
||||
|
||||
def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
"""Remove CLI project directories older than ``_STALE_PROJECT_DIR_SECONDS``.
|
||||
|
||||
Each CoPilot SDK turn creates a unique ``~/.claude/projects/<encoded-cwd>/``
|
||||
directory. These are intentionally kept across turns so the model can read
|
||||
tool-result files via ``--resume``. However, after a session ends they
|
||||
become stale. This function sweeps old ones to prevent unbounded disk
|
||||
growth.
|
||||
|
||||
When *encoded_cwd* is provided the sweep is scoped to that single
|
||||
directory, making the operation safe in multi-tenant environments where
|
||||
multiple copilot sessions share the same host. Without it the function
|
||||
falls back to sweeping all directories matching the copilot naming pattern
|
||||
(``-tmp-copilot-``), which is only safe for single-tenant deployments.
|
||||
|
||||
Returns the number of directories removed.
|
||||
"""
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
projects_base = _projects_base()
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
if not os.path.isdir(projects_base):
|
||||
return 0
|
||||
|
||||
if not project_dir.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"[Transcript] Project dir escaped projects base: %s", project_dir
|
||||
)
|
||||
return None
|
||||
return project_dir
|
||||
now = time.time()
|
||||
removed = 0
|
||||
|
||||
|
||||
def _safe_glob_jsonl(project_dir: str) -> list[Path]:
|
||||
"""Glob ``*.jsonl`` files, filtering out symlinks that escape the directory."""
|
||||
try:
|
||||
resolved_base = Path(project_dir).resolve()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to resolve project dir: %s", e)
|
||||
return []
|
||||
|
||||
result: list[Path] = []
|
||||
for candidate in Path(project_dir).glob("*.jsonl"):
|
||||
try:
|
||||
resolved = candidate.resolve()
|
||||
if resolved.is_relative_to(resolved_base):
|
||||
result.append(resolved)
|
||||
except (OSError, RuntimeError) as e:
|
||||
logger.debug(
|
||||
"[Transcript] Skipping invalid CLI session candidate %s: %s",
|
||||
candidate,
|
||||
e,
|
||||
# Scoped mode: only clean up the one directory for the current session.
|
||||
if encoded_cwd:
|
||||
target = Path(projects_base) / encoded_cwd
|
||||
if not target.is_dir():
|
||||
return 0
|
||||
# Guard: only sweep copilot-generated dirs.
|
||||
if "-tmp-copilot-" not in target.name:
|
||||
logger.warning(
|
||||
"[Transcript] Refusing to sweep non-copilot dir: %s", target.name
|
||||
)
|
||||
return result
|
||||
return 0
|
||||
try:
|
||||
# st_mtime is used as a proxy for session activity. Claude CLI writes
|
||||
# its JSONL transcript into this directory during each turn, so mtime
|
||||
# advances on every turn. A directory whose mtime is older than
|
||||
# _STALE_PROJECT_DIR_SECONDS has not had an active turn in that window
|
||||
# and is safe to remove (the session cannot --resume after cleanup).
|
||||
age = now - target.stat().st_mtime
|
||||
except OSError:
|
||||
return 0
|
||||
if age < _STALE_PROJECT_DIR_SECONDS:
|
||||
return 0
|
||||
try:
|
||||
shutil.rmtree(target, ignore_errors=True)
|
||||
removed = 1
|
||||
except OSError:
|
||||
pass
|
||||
if removed:
|
||||
logger.info(
|
||||
"[Transcript] Swept stale CLI project dir %s (age %ds > %ds)",
|
||||
target.name,
|
||||
int(age),
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
)
|
||||
return removed
|
||||
|
||||
# Unscoped fallback: sweep all copilot dirs across the projects base.
|
||||
# Only safe for single-tenant deployments; callers should prefer the
|
||||
# scoped variant by passing encoded_cwd.
|
||||
try:
|
||||
entries = Path(projects_base).iterdir()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to list projects dir: %s", e)
|
||||
return 0
|
||||
|
||||
for entry in entries:
|
||||
if removed >= _MAX_PROJECT_DIRS_TO_SWEEP:
|
||||
break
|
||||
# Only sweep copilot-generated dirs (pattern: -tmp-copilot- or
|
||||
# -private-tmp-copilot-).
|
||||
if "-tmp-copilot-" not in entry.name:
|
||||
continue
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
try:
|
||||
# See the scoped-mode comment above: st_mtime advances on every turn,
|
||||
# so a stale mtime reliably indicates an inactive session.
|
||||
age = now - entry.stat().st_mtime
|
||||
except OSError:
|
||||
continue
|
||||
if age < _STALE_PROJECT_DIR_SECONDS:
|
||||
continue
|
||||
|
||||
try:
|
||||
shutil.rmtree(entry, ignore_errors=True)
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if removed:
|
||||
logger.info(
|
||||
"[Transcript] Swept %d stale CLI project dirs (older than %ds)",
|
||||
removed,
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
)
|
||||
return removed
|
||||
|
||||
|
||||
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
@@ -255,63 +335,6 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
return entries
|
||||
|
||||
|
||||
def read_cli_session_file(sdk_cwd: str) -> str | None:
|
||||
"""Read the CLI's own session file, which reflects any compaction.
|
||||
|
||||
The CLI writes its session transcript to
|
||||
``~/.claude/projects/<encoded_cwd>/<session_id>.jsonl``.
|
||||
Since each SDK turn uses a unique ``sdk_cwd``, there should be
|
||||
exactly one ``.jsonl`` file in that directory.
|
||||
|
||||
Returns the file content, or ``None`` if not found.
|
||||
"""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir or not os.path.isdir(project_dir):
|
||||
return None
|
||||
|
||||
jsonl_files = _safe_glob_jsonl(project_dir)
|
||||
if not jsonl_files:
|
||||
logger.debug("[Transcript] No CLI session file found in %s", project_dir)
|
||||
return None
|
||||
|
||||
# Pick the most recently modified file (should be only one per turn).
|
||||
try:
|
||||
session_file = max(jsonl_files, key=lambda p: p.stat().st_mtime)
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to inspect CLI session files: %s", e)
|
||||
return None
|
||||
|
||||
try:
|
||||
content = session_file.read_text()
|
||||
logger.info(
|
||||
"[Transcript] Read CLI session file: %s (%d bytes)",
|
||||
session_file,
|
||||
len(content),
|
||||
)
|
||||
return content
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||
safe to remove entirely after the transcript has been uploaded.
|
||||
"""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir:
|
||||
return
|
||||
|
||||
if os.path.isdir(project_dir):
|
||||
shutil.rmtree(project_dir, ignore_errors=True)
|
||||
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
|
||||
else:
|
||||
logger.debug("[Transcript] Project dir not found: %s", project_dir)
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
transcript_content: str,
|
||||
session_id: str,
|
||||
@@ -327,7 +350,7 @@ def write_transcript_to_tempfile(
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -337,17 +360,17 @@ def write_transcript_to_tempfile(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
logger.warning("[Transcript] Failed to write resume file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
@@ -408,8 +431,6 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
|
||||
|
||||
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
wid, fid, fname = parts
|
||||
if isinstance(backend, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||
@@ -448,17 +469,15 @@ async def upload_transcript(
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Strip metadata entries (progress, file-history-snapshot, etc.)
|
||||
# Note: SDK-built transcripts shouldn't have these, but strip for safety
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
|
||||
entry_types.append(entry.get("type", "?"))
|
||||
entry_types = [
|
||||
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
|
||||
for line in stripped.strip().split("\n")
|
||||
]
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
@@ -494,11 +513,14 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -512,8 +534,6 @@ async def download_transcript(
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -521,10 +541,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -536,10 +556,14 @@ async def download_transcript(
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except (FileNotFoundError, Exception):
|
||||
except FileNotFoundError:
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
except Exception as e:
|
||||
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
@@ -553,8 +577,6 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -571,3 +593,280 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction — LLM summarization for prompt-too-long recovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# JSONL protocol values used in transcript serialization.
|
||||
STOP_REASON_END_TURN = "end_turn"
|
||||
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string.
|
||||
|
||||
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
|
||||
placeholders. This is intentional: ``compress_context`` requires plain
|
||||
text for token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript was
|
||||
already too large for the model — a summarized plain-text version is
|
||||
better than no context at all.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
btype = block.get("type", "")
|
||||
if btype == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif btype == "tool_use":
|
||||
parts.append(f"[tool_use: {block.get('name', '?')}]")
|
||||
else:
|
||||
# Preserve non-text blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _flatten_tool_result_content(blocks: list) -> str:
|
||||
"""Flatten tool_result and other content blocks into plain text.
|
||||
|
||||
Handles nested tool_result structures, text blocks, and raw strings.
|
||||
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
|
||||
or where ``text`` is ``None``.
|
||||
|
||||
Like ``_flatten_assistant_content``, structured blocks (images, nested
|
||||
tool results) are reduced to text representations for compression.
|
||||
"""
|
||||
str_parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
inner = block.get("content") or ""
|
||||
if isinstance(inner, list):
|
||||
for sub in inner:
|
||||
if isinstance(sub, dict):
|
||||
sub_type = sub.get("type")
|
||||
if sub_type in ("image", "document"):
|
||||
# Avoid serializing base64 binary data into
|
||||
# the compaction input — use a placeholder.
|
||||
str_parts.append(f"[__{sub_type}__]")
|
||||
elif sub_type == "text" or sub.get("text") is not None:
|
||||
str_parts.append(str(sub.get("text", "")))
|
||||
else:
|
||||
str_parts.append(json.dumps(sub))
|
||||
else:
|
||||
str_parts.append(str(sub))
|
||||
else:
|
||||
str_parts.append(str(inner))
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
str_parts.append(str(block.get("text", "")))
|
||||
elif isinstance(block, dict):
|
||||
# Preserve non-text/non-tool_result blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
btype = block.get("type", "unknown")
|
||||
str_parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
str_parts.append(block)
|
||||
return "\n".join(str_parts) if str_parts else ""
|
||||
|
||||
|
||||
def _transcript_to_messages(content: str) -> list[dict]:
|
||||
"""Convert JSONL transcript entries to plain message dicts for compression.
|
||||
|
||||
Parses each line of the JSONL *content*, skips strippable metadata entries
|
||||
(progress, file-history-snapshot, etc.), and extracts the ``role`` and
|
||||
flattened ``content`` from the ``message`` field of each remaining entry.
|
||||
|
||||
Structured content blocks (``tool_use``, ``tool_result``, images) are
|
||||
flattened to plain text via ``_flatten_assistant_content`` and
|
||||
``_flatten_tool_result_content`` so that ``compress_context`` can
|
||||
perform token counting and LLM summarization on uniform strings.
|
||||
|
||||
Returns:
|
||||
A list of ``{"role": str, "content": str}`` dicts suitable for
|
||||
``compress_context``.
|
||||
"""
|
||||
messages: list[dict] = []
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
role = msg.get("role", "")
|
||||
if not role:
|
||||
continue
|
||||
msg_dict: dict = {"role": role}
|
||||
raw_content = msg.get("content")
|
||||
if role == "assistant" and isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_assistant_content(raw_content)
|
||||
elif isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_tool_result_content(raw_content)
|
||||
else:
|
||||
msg_dict["content"] = raw_content or ""
|
||||
messages.append(msg_dict)
|
||||
return messages
|
||||
|
||||
|
||||
def _messages_to_transcript(messages: list[dict]) -> str:
|
||||
"""Convert compressed message dicts back to JSONL transcript format.
|
||||
|
||||
Rebuilds a minimal JSONL transcript from the ``{"role", "content"}``
|
||||
dicts returned by ``compress_context``. Each message becomes one JSONL
|
||||
line with a fresh ``uuid`` / ``parentUuid`` chain so the CLI's
|
||||
``--resume`` flag can reconstruct a valid conversation tree.
|
||||
|
||||
Assistant messages are wrapped in the full ``message`` envelope
|
||||
(``id``, ``model``, ``stop_reason``, structured ``content`` blocks)
|
||||
that the CLI expects. User messages use the simpler ``{role, content}``
|
||||
form.
|
||||
|
||||
Returns:
|
||||
A newline-terminated JSONL string, or an empty string if *messages*
|
||||
is empty.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str = "" # root entry uses empty string, not null
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
uid = str(uuid4())
|
||||
content = msg.get("content", "")
|
||||
if role == "assistant":
|
||||
message: dict = {
|
||||
"role": "assistant",
|
||||
"model": "",
|
||||
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
|
||||
"type": ENTRY_TYPE_MESSAGE,
|
||||
"content": [{"type": "text", "text": content}] if content else [],
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
message = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": message,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n" if lines else ""
|
||||
|
||||
|
||||
_COMPACTION_TIMEOUT_SECONDS = 60
|
||||
_TRUNCATION_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
async def _run_compression(
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
log_prefix: str,
|
||||
) -> CompressResult:
|
||||
"""Run LLM-based compression with truncation fallback.
|
||||
|
||||
Uses the shared OpenAI client from ``get_openai_client()``.
|
||||
If no client is configured or the LLM call fails, falls back to
|
||||
truncation-based compression which drops older messages without
|
||||
summarization.
|
||||
|
||||
A 60-second timeout prevents a hung LLM call from blocking the
|
||||
retry path indefinitely. The truncation fallback also has a
|
||||
30-second timeout to guard against slow tokenization on very large
|
||||
transcripts.
|
||||
"""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=client),
|
||||
timeout=_COMPACTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
*,
|
||||
model: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str | None:
|
||||
"""Compact an oversized JSONL transcript using LLM summarization.
|
||||
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
|
||||
is flattened to plain text for compression. This matches the fidelity of
|
||||
the Plan C (DB compression) fallback path, where
|
||||
``_format_conversation_context`` similarly renders tool calls as
|
||||
``You called tool: name(args)`` and results as ``Tool result: ...``.
|
||||
Neither path preserves structured API content blocks — the compacted
|
||||
context serves as text history for the LLM, which creates proper
|
||||
structured tool calls going forward.
|
||||
|
||||
Images are per-turn attachments loaded from workspace storage by file ID
|
||||
(via ``_prepare_file_attachments``), not part of the conversation history.
|
||||
They are re-attached each turn and are unaffected by compaction.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
|
||||
See also:
|
||||
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
|
||||
lists for pre-query DB history. Both share ``compress_context()``
|
||||
but operate on different input formats (JSONL transcript entries
|
||||
here vs. ChatMessage dicts there).
|
||||
"""
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, model, log_prefix)
|
||||
if not result.was_compacted:
|
||||
# Compressor says it's within budget, but the SDK rejected it.
|
||||
# Return None so the caller falls through to DB fallback.
|
||||
logger.warning(
|
||||
"%s Compressor reports within budget but SDK rejected — "
|
||||
"signalling failure",
|
||||
log_prefix,
|
||||
)
|
||||
return None
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
return compacted
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -68,7 +68,7 @@ class TranscriptBuilder:
|
||||
type=entry_type,
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
isCompactSummary=data.get("isCompactSummary") or None,
|
||||
isCompactSummary=data.get("isCompactSummary"),
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -9,9 +10,7 @@ from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
_cli_project_dir,
|
||||
delete_transcript,
|
||||
read_cli_session_file,
|
||||
read_compacted_entries,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
@@ -292,85 +291,6 @@ class TestStripProgressEntries:
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
|
||||
# --- read_cli_session_file ---
|
||||
|
||||
|
||||
class TestReadCliSessionFile:
|
||||
def test_no_matching_files_returns_none(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file returns None when no .jsonl files exist."""
|
||||
# Create a project dir with no jsonl files
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
assert read_cli_session_file("/fake/cwd") is None
|
||||
|
||||
def test_one_jsonl_file_returns_content(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file returns the content of a single .jsonl file."""
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
jsonl_file = project_dir / "session.jsonl"
|
||||
jsonl_file.write_text("line1\nline2\n")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
result = read_cli_session_file("/fake/cwd")
|
||||
assert result == "line1\nline2\n"
|
||||
|
||||
def test_symlink_escaping_project_dir_is_skipped(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file skips symlinks that escape the project dir."""
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
|
||||
# Create a file outside the project dir
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
outside_file = outside / "evil.jsonl"
|
||||
outside_file.write_text("should not be read\n")
|
||||
|
||||
# Symlink from inside project_dir to outside file
|
||||
symlink = project_dir / "evil.jsonl"
|
||||
symlink.symlink_to(outside_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
# The symlink target resolves outside project_dir, so it should be skipped
|
||||
result = read_cli_session_file("/fake/cwd")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- _cli_project_dir ---
|
||||
|
||||
|
||||
class TestCliProjectDir:
|
||||
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
|
||||
"""_cli_project_dir returns None when the project dir symlink escapes projects base."""
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir()
|
||||
projects_dir = config_dir / "projects"
|
||||
projects_dir.mkdir()
|
||||
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# Create a symlink inside projects/ that points outside of it.
|
||||
# _cli_project_dir encodes the cwd as all-alnum-hyphens, so use a
|
||||
# cwd whose encoded form matches the symlink name we create.
|
||||
evil_target = tmp_path / "escaped"
|
||||
evil_target.mkdir()
|
||||
|
||||
# The encoded form of "/evil/cwd" is "-evil-cwd"
|
||||
symlink_path = projects_dir / "-evil-cwd"
|
||||
symlink_path.symlink_to(evil_target)
|
||||
|
||||
result = _cli_project_dir("/evil/cwd")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- delete_transcript ---
|
||||
|
||||
|
||||
@@ -382,7 +302,7 @@ class TestDeleteTranscript:
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -402,7 +322,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -420,7 +340,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -897,3 +817,386 @@ class TestCompactionFlowIntegration:
|
||||
output2 = builder2.to_jsonl()
|
||||
lines2 = [json.loads(line) for line in output2.strip().split("\n")]
|
||||
assert lines2[-1]["parentUuid"] == "a2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_compression (direct tests for the 3 code paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunCompression:
|
||||
"""Direct tests for ``_run_compression`` covering all 3 code paths.
|
||||
|
||||
Paths:
|
||||
(a) No OpenAI client configured → truncation fallback immediately.
|
||||
(b) LLM success → returns LLM-compressed result.
|
||||
(c) LLM call raises → truncation fallback.
|
||||
"""
|
||||
|
||||
def _make_compress_result(self, was_compacted: bool, msgs=None):
|
||||
"""Build a minimal CompressResult-like object."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
return SimpleNamespace(
|
||||
was_compacted=was_compacted,
|
||||
messages=msgs or [{"role": "user", "content": "summary"}],
|
||||
original_token_count=500,
|
||||
token_count=100 if was_compacted else 500,
|
||||
messages_summarized=2 if was_compacted else 0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation(self):
|
||||
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated"}]
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "hello"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with client=None (truncation mode)
|
||||
call_kwargs = mock_compress.call_args
|
||||
assert (
|
||||
call_kwargs.kwargs.get("client") is None
|
||||
or (call_kwargs.args and call_kwargs.args[2] is None)
|
||||
or mock_compress.call_args[1].get("client") is None
|
||||
)
|
||||
assert result is truncation_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_success_returns_llm_result(self):
|
||||
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
llm_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "LLM summary"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=llm_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with the real client
|
||||
assert mock_compress.called
|
||||
assert result is llm_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_falls_back_to_truncation(self):
|
||||
"""Path (c): LLM call raises → truncation fallback used instead."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated fallback"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
call_count = [0]
|
||||
|
||||
async def _compress_side_effect(**kwargs):
|
||||
call_count[0] += 1
|
||||
if kwargs.get("client") is not None:
|
||||
raise RuntimeError("LLM timeout")
|
||||
return truncation_result
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called twice: once for LLM (raises), once for truncation
|
||||
assert call_count[0] == 2
|
||||
assert result is truncation_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_timeout_falls_back_to_truncation(self):
|
||||
"""Path (d): LLM call exceeds timeout → truncation fallback used."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated after timeout"}]
|
||||
)
|
||||
call_count = [0]
|
||||
|
||||
async def _compress_side_effect(*, messages, model, client):
|
||||
call_count[0] += 1
|
||||
if client is not None:
|
||||
# Simulate a hang that exceeds the timeout
|
||||
await asyncio.sleep(9999)
|
||||
return truncation_result
|
||||
|
||||
fake_client = MagicMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=fake_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._COMPACTION_TIMEOUT_SECONDS",
|
||||
0.05,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._TRUNCATION_TIMEOUT_SECONDS",
|
||||
5,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called twice: once for LLM (times out), once truncation
|
||||
assert call_count[0] == 2
|
||||
assert result is truncation_result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cleanup_stale_project_dirs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupStaleProjectDirs:
|
||||
"""Tests for cleanup_stale_project_dirs (disk leak prevention)."""
|
||||
|
||||
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories matching copilot pattern older than threshold are removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
# Create a stale dir
|
||||
stale = projects_dir / "-tmp-copilot-old-session"
|
||||
stale.mkdir()
|
||||
# Set mtime to past the threshold
|
||||
import time
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
os.utime(stale, (old_time, old_time))
|
||||
|
||||
# Create a fresh dir
|
||||
fresh = projects_dir / "-tmp-copilot-new-session"
|
||||
fresh.mkdir()
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 1
|
||||
assert not stale.exists()
|
||||
assert fresh.exists()
|
||||
|
||||
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories not matching copilot pattern are left alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
# Non-copilot dir that's old
|
||||
import time
|
||||
|
||||
other = projects_dir / "some-other-project"
|
||||
other.mkdir()
|
||||
old_time = time.time() - 999999
|
||||
os.utime(other, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert other.exists()
|
||||
|
||||
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
|
||||
"""A directory exactly at the TTL boundary should NOT be removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
# Dir that's exactly at the TTL (age == threshold, not >) — should survive
|
||||
boundary = projects_dir / "-tmp-copilot-boundary"
|
||||
boundary.mkdir()
|
||||
boundary_time = time.time() - _STALE_PROJECT_DIR_SECONDS + 1
|
||||
os.utime(boundary, (boundary_time, boundary_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert boundary.exists()
|
||||
|
||||
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
|
||||
"""Regular files matching the copilot pattern are not removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
# Create a regular FILE (not a dir) with the copilot pattern name
|
||||
stale_file = projects_dir / "-tmp-copilot-stale-file"
|
||||
stale_file.write_text("not a dir")
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
os.utime(stale_file, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert stale_file.exists()
|
||||
|
||||
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
|
||||
"""If the projects base directory doesn't exist, return 0 gracefully."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: nonexistent,
|
||||
)
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
|
||||
def test_scoped_removes_only_target_dir(self, tmp_path, monkeypatch):
|
||||
"""When encoded_cwd is supplied only that directory is swept."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
|
||||
# Two stale copilot dirs
|
||||
target = projects_dir / "-tmp-copilot-session-abc"
|
||||
target.mkdir()
|
||||
os.utime(target, (old_time, old_time))
|
||||
|
||||
other = projects_dir / "-tmp-copilot-session-xyz"
|
||||
other.mkdir()
|
||||
os.utime(other, (old_time, old_time))
|
||||
|
||||
# Only the target dir should be removed
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-abc")
|
||||
assert removed == 1
|
||||
assert not target.exists()
|
||||
assert other.exists() # untouched — not the current session
|
||||
|
||||
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep leaves a fresh directory alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
fresh = projects_dir / "-tmp-copilot-session-new"
|
||||
fresh.mkdir()
|
||||
# mtime is now — well within TTL
|
||||
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-new")
|
||||
assert removed == 0
|
||||
assert fresh.exists()
|
||||
|
||||
def test_scoped_non_copilot_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep refuses to remove a non-copilot directory."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
non_copilot = projects_dir / "some-other-project"
|
||||
non_copilot.mkdir()
|
||||
os.utime(non_copilot, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
|
||||
assert removed == 0
|
||||
assert non_copilot.exists()
|
||||
|
||||
@@ -4,11 +4,12 @@ These tests verify the complete copilot flow using dummy implementations
|
||||
for agent generator and SDK service, allowing automated testing without
|
||||
external LLM calls.
|
||||
|
||||
Enable test mode with COPILOT_TEST_MODE=true environment variable.
|
||||
Enable test mode with CHAT_TEST_MODE=true environment variable (or in .env).
|
||||
|
||||
Note: StreamFinish is NOT emitted by the dummy service — it is published
|
||||
by mark_session_completed in the processor layer. These tests only cover
|
||||
the service-level streaming output (StreamStart + StreamTextDelta).
|
||||
The dummy service emits the full AI SDK protocol event sequence:
|
||||
StreamStart → StreamStartStep → StreamTextStart → StreamTextDelta(s) →
|
||||
StreamTextEnd → StreamFinishStep → StreamFinish.
|
||||
The processor skips StreamFinish and publishes its own via mark_session_completed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -20,9 +21,14 @@ import pytest
|
||||
from backend.copilot.model import ChatMessage, ChatSession, upsert_chat_session
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
)
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
|
||||
@@ -30,9 +36,9 @@ from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_test_mode():
|
||||
"""Enable test mode for all tests in this module."""
|
||||
os.environ["COPILOT_TEST_MODE"] = "true"
|
||||
os.environ["CHAT_TEST_MODE"] = "true"
|
||||
yield
|
||||
os.environ.pop("COPILOT_TEST_MODE", None)
|
||||
os.environ.pop("CHAT_TEST_MODE", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -110,9 +116,14 @@ async def test_streaming_event_types():
|
||||
):
|
||||
event_types.add(type(event).__name__)
|
||||
|
||||
# Required event types (StreamFinish is published by processor, not service)
|
||||
# Required event types for full AI SDK protocol
|
||||
assert "StreamStart" in event_types, "Missing StreamStart"
|
||||
assert "StreamStartStep" in event_types, "Missing StreamStartStep"
|
||||
assert "StreamTextStart" in event_types, "Missing StreamTextStart"
|
||||
assert "StreamTextDelta" in event_types, "Missing StreamTextDelta"
|
||||
assert "StreamTextEnd" in event_types, "Missing StreamTextEnd"
|
||||
assert "StreamFinishStep" in event_types, "Missing StreamFinishStep"
|
||||
assert "StreamFinish" in event_types, "Missing StreamFinish"
|
||||
|
||||
print(f"✅ Event types: {sorted(event_types)}")
|
||||
|
||||
@@ -175,16 +186,17 @@ async def test_streaming_heartbeat_timing():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling():
|
||||
"""Test that errors are properly formatted and sent."""
|
||||
# This would require a dummy that can trigger errors
|
||||
# For now, just verify error event structure
|
||||
|
||||
"""Test that error events have correct SSE structure."""
|
||||
error = StreamError(errorText="Test error", code="test_error")
|
||||
assert error.errorText == "Test error"
|
||||
assert error.code == "test_error"
|
||||
assert str(error.type.value) in ["error", "error"]
|
||||
|
||||
print("✅ Error structure verified")
|
||||
# Verify to_sse() strips code (AI SDK protocol compliance)
|
||||
sse = error.to_sse()
|
||||
assert '"errorText"' in sse
|
||||
assert '"code"' not in sse, "to_sse() must strip code field for AI SDK"
|
||||
|
||||
print("✅ Error structure verified (code stripped in SSE)")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -326,20 +338,85 @@ async def test_stream_completeness():
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Check for required events (StreamFinish is published by processor)
|
||||
has_start = any(isinstance(e, StreamStart) for e in events)
|
||||
has_text = any(isinstance(e, StreamTextDelta) for e in events)
|
||||
|
||||
assert has_start, "Stream must include StreamStart"
|
||||
assert has_text, "Stream must include text deltas"
|
||||
# Check for all required event types
|
||||
assert any(isinstance(e, StreamStart) for e in events), "Missing StreamStart"
|
||||
assert any(
|
||||
isinstance(e, StreamStartStep) for e in events
|
||||
), "Missing StreamStartStep"
|
||||
assert any(
|
||||
isinstance(e, StreamTextStart) for e in events
|
||||
), "Missing StreamTextStart"
|
||||
assert any(
|
||||
isinstance(e, StreamTextDelta) for e in events
|
||||
), "Missing StreamTextDelta"
|
||||
assert any(isinstance(e, StreamTextEnd) for e in events), "Missing StreamTextEnd"
|
||||
assert any(
|
||||
isinstance(e, StreamFinishStep) for e in events
|
||||
), "Missing StreamFinishStep"
|
||||
assert any(isinstance(e, StreamFinish) for e in events), "Missing StreamFinish"
|
||||
|
||||
# Verify exactly one start
|
||||
start_count = sum(1 for e in events if isinstance(e, StreamStart))
|
||||
assert start_count == 1, f"Should have exactly 1 StreamStart, got {start_count}"
|
||||
|
||||
print(
|
||||
f"✅ Completeness: 1 start, {sum(1 for e in events if isinstance(e, StreamTextDelta))} text deltas"
|
||||
)
|
||||
print(f"✅ Completeness: {len(events)} events, full protocol sequence")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transient_error_shows_retryable():
|
||||
"""Test __test_transient_error__ yields partial text then retryable StreamError."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-transient",
|
||||
message="please fail __test_transient_error__",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Should start with StreamStart
|
||||
assert isinstance(events[0], StreamStart)
|
||||
|
||||
# Should have some partial text before the error
|
||||
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
|
||||
assert len(text_events) > 0, "Should stream partial text before error"
|
||||
|
||||
# Should end with StreamError
|
||||
error_events = [e for e in events if isinstance(e, StreamError)]
|
||||
assert len(error_events) == 1, "Should have exactly one StreamError"
|
||||
assert error_events[0].code == "transient_api_error"
|
||||
assert "connection interrupted" in error_events[0].errorText.lower()
|
||||
|
||||
print(f"✅ Transient error: {len(text_events)} partial deltas + retryable error")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fatal_error_not_retryable():
|
||||
"""Test __test_fatal_error__ yields StreamError without retryable code."""
|
||||
events = []
|
||||
|
||||
async for event in stream_chat_completion_dummy(
|
||||
session_id="test-fatal",
|
||||
message="__test_fatal_error__",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert isinstance(events[0], StreamStart)
|
||||
|
||||
# Should have StreamError with sdk_error code (not transient)
|
||||
error_events = [e for e in events if isinstance(e, StreamError)]
|
||||
assert len(error_events) == 1
|
||||
assert error_events[0].code == "sdk_error"
|
||||
assert "transient" not in error_events[0].code
|
||||
|
||||
# Should NOT have any text deltas (fatal errors fail immediately)
|
||||
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
|
||||
assert len(text_events) == 0, "Fatal error should not stream any text"
|
||||
|
||||
print("✅ Fatal error: immediate error, no partial text")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -395,6 +472,8 @@ if __name__ == "__main__":
|
||||
asyncio.run(test_message_deduplication())
|
||||
asyncio.run(test_event_ordering())
|
||||
asyncio.run(test_stream_completeness())
|
||||
asyncio.run(test_transient_error_shows_retryable())
|
||||
asyncio.run(test_fatal_error_not_retryable())
|
||||
asyncio.run(test_text_delta_consistency())
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
@@ -84,6 +85,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
"connect_integration": ConnectIntegrationTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
|
||||
@@ -20,7 +20,9 @@ SSRF protection:
|
||||
|
||||
Requires:
|
||||
npm install -g agent-browser
|
||||
agent-browser install (downloads Chromium, one-time per machine)
|
||||
agent-browser install (downloads Chromium, one-time — skipped in Docker
|
||||
where system chromium is pre-installed and
|
||||
AGENT_BROWSER_EXECUTABLE_PATH is set)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any
|
||||
from .helpers import (
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
SMART_DECISION_MAKER_BLOCK_ID,
|
||||
AgentDict,
|
||||
are_types_compatible,
|
||||
generate_uuid,
|
||||
@@ -30,6 +31,14 @@ _GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
||||
_GMAIL_SEND_BLOCK_ID = "6c27abc2-e51d-499e-a85f-5a0041ba94f0"
|
||||
_TEXT_REPLACE_BLOCK_ID = "7e7c87ab-3469-4bcc-9abe-67705091b713"
|
||||
|
||||
# Defaults applied to SmartDecisionMakerBlock nodes by the fixer.
|
||||
_SDM_DEFAULTS: dict[str, int | bool] = {
|
||||
"agent_mode_max_iterations": 10,
|
||||
"conversation_compaction": True,
|
||||
"retry": 3,
|
||||
"multiple_tool_calls": False,
|
||||
}
|
||||
|
||||
|
||||
class AgentFixer:
|
||||
"""
|
||||
@@ -1630,6 +1639,43 @@ class AgentFixer:
|
||||
|
||||
return agent
|
||||
|
||||
def fix_smart_decision_maker_blocks(self, agent: AgentDict) -> AgentDict:
|
||||
"""Fix SmartDecisionMakerBlock nodes to ensure agent-mode defaults.
|
||||
|
||||
Ensures:
|
||||
1. ``agent_mode_max_iterations`` defaults to ``10`` (bounded agent mode)
|
||||
2. ``conversation_compaction`` defaults to ``True``
|
||||
3. ``retry`` defaults to ``3``
|
||||
4. ``multiple_tool_calls`` defaults to ``False``
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to fix
|
||||
|
||||
Returns:
|
||||
The fixed agent dictionary
|
||||
"""
|
||||
nodes = agent.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != SMART_DECISION_MAKER_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id", "unknown")
|
||||
input_default = node.get("input_default")
|
||||
if not isinstance(input_default, dict):
|
||||
input_default = {}
|
||||
node["input_default"] = input_default
|
||||
|
||||
for field, default_value in _SDM_DEFAULTS.items():
|
||||
if field not in input_default or input_default[field] is None:
|
||||
input_default[field] = default_value
|
||||
self.add_fix_log(
|
||||
f"SmartDecisionMakerBlock {node_id}: "
|
||||
f"Set {field}={default_value!r}"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
def fix_dynamic_block_sink_names(self, agent: AgentDict) -> AgentDict:
|
||||
"""Fix links that use _#_ notation for dynamic block sink names.
|
||||
|
||||
@@ -1717,6 +1763,9 @@ class AgentFixer:
|
||||
# Apply fixes for MCPToolBlock nodes
|
||||
agent = self.fix_mcp_tool_blocks(agent)
|
||||
|
||||
# Apply fixes for SmartDecisionMakerBlock nodes (agent-mode defaults)
|
||||
agent = self.fix_smart_decision_maker_blocks(agent)
|
||||
|
||||
# Apply fixes for AgentExecutorBlock nodes (sub-agents)
|
||||
if library_agents:
|
||||
agent = self.fix_agent_executor_blocks(agent, library_agents)
|
||||
|
||||
@@ -12,6 +12,7 @@ __all__ = [
|
||||
"AGENT_OUTPUT_BLOCK_ID",
|
||||
"AgentDict",
|
||||
"MCP_TOOL_BLOCK_ID",
|
||||
"SMART_DECISION_MAKER_BLOCK_ID",
|
||||
"UUID_REGEX",
|
||||
"are_types_compatible",
|
||||
"generate_uuid",
|
||||
@@ -33,6 +34,7 @@ UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
|
||||
|
||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||
SMART_DECISION_MAKER_BLOCK_ID = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from .helpers import (
|
||||
AGENT_INPUT_BLOCK_ID,
|
||||
AGENT_OUTPUT_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
SMART_DECISION_MAKER_BLOCK_ID,
|
||||
AgentDict,
|
||||
are_types_compatible,
|
||||
get_defined_property_type,
|
||||
@@ -181,15 +182,23 @@ class AgentValidator:
|
||||
|
||||
return valid
|
||||
|
||||
def _build_node_lookup(self, agent: AgentDict) -> dict[str, dict[str, Any]]:
|
||||
"""Build a node-id → node dict from the agent's nodes."""
|
||||
return {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
|
||||
def validate_data_type_compatibility(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that linked data types are compatible between source and sink.
|
||||
Returns True if all data types are compatible, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
block_lookup = {block.get("id", ""): block for block in blocks}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
@@ -209,8 +218,8 @@ class AgentValidator:
|
||||
valid = False
|
||||
continue
|
||||
|
||||
source_node = node_lookup.get(source_id, "")
|
||||
sink_node = node_lookup.get(sink_id, "")
|
||||
source_node = node_lookup.get(source_id)
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
@@ -248,7 +257,10 @@ class AgentValidator:
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate nested sink links (links with _#_ notation).
|
||||
@@ -262,7 +274,8 @@ class AgentValidator:
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
@@ -388,7 +401,10 @@ class AgentValidator:
|
||||
return valid
|
||||
|
||||
def validate_source_output_existence(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that all source_names in links exist in the corresponding
|
||||
@@ -401,6 +417,7 @@ class AgentValidator:
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
blocks: List of available blocks with their schemas
|
||||
node_lookup: Optional pre-built node-id → node dict
|
||||
|
||||
Returns:
|
||||
True if all source output fields exist, False otherwise
|
||||
@@ -415,7 +432,8 @@ class AgentValidator:
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_id = link.get("source_id")
|
||||
@@ -809,6 +827,96 @@ class AgentValidator:
|
||||
|
||||
return valid
|
||||
|
||||
def validate_smart_decision_maker_blocks(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
node_lookup: dict[str, dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""Validate that SmartDecisionMakerBlock nodes have downstream tools.
|
||||
|
||||
Checks that each SmartDecisionMakerBlock node has at least one link
|
||||
with ``source_name == "tools"`` connecting to a downstream block.
|
||||
Without tools, the block has nothing to call and will error at runtime.
|
||||
|
||||
Returns True if all SmartDecisionMakerBlock nodes are valid.
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
if node_lookup is None:
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
non_tool_block_ids = {AGENT_INPUT_BLOCK_ID, AGENT_OUTPUT_BLOCK_ID}
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != SMART_DECISION_MAKER_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id", "unknown")
|
||||
customized_name = (node.get("metadata") or {}).get(
|
||||
"customized_name", node_id
|
||||
)
|
||||
|
||||
# Warn if agent_mode_max_iterations is 0 (traditional mode) —
|
||||
# requires complex external conversation-history loop wiring
|
||||
# that the agent generator does not produce.
|
||||
input_default = node.get("input_default", {})
|
||||
max_iter = input_default.get("agent_mode_max_iterations")
|
||||
if max_iter is not None and not isinstance(max_iter, int):
|
||||
self.add_error(
|
||||
f"SmartDecisionMakerBlock node '{customized_name}' "
|
||||
f"({node_id}) has non-integer "
|
||||
f"agent_mode_max_iterations={max_iter!r}. "
|
||||
f"This field must be an integer."
|
||||
)
|
||||
valid = False
|
||||
elif isinstance(max_iter, int) and max_iter < -1:
|
||||
self.add_error(
|
||||
f"SmartDecisionMakerBlock node '{customized_name}' "
|
||||
f"({node_id}) has invalid "
|
||||
f"agent_mode_max_iterations={max_iter}. "
|
||||
f"Use -1 for infinite or a positive number for "
|
||||
f"bounded iterations."
|
||||
)
|
||||
valid = False
|
||||
elif isinstance(max_iter, int) and max_iter > 100:
|
||||
self.add_error(
|
||||
f"SmartDecisionMakerBlock node '{customized_name}' "
|
||||
f"({node_id}) has agent_mode_max_iterations="
|
||||
f"{max_iter} which is unusually high. Values above "
|
||||
f"100 risk excessive cost and long execution times. "
|
||||
f"Consider using a lower value (3-10) or -1 for "
|
||||
f"genuinely open-ended tasks."
|
||||
)
|
||||
valid = False
|
||||
elif max_iter == 0:
|
||||
self.add_error(
|
||||
f"SmartDecisionMakerBlock node '{customized_name}' "
|
||||
f"({node_id}) has agent_mode_max_iterations=0 "
|
||||
f"(traditional mode). The agent generator only supports "
|
||||
f"agent mode (set to -1 for infinite or a positive "
|
||||
f"number for bounded iterations)."
|
||||
)
|
||||
valid = False
|
||||
|
||||
has_tools = any(
|
||||
link.get("source_id") == node_id
|
||||
and link.get("source_name") == "tools"
|
||||
and node_lookup.get(link.get("sink_id", ""), {}).get("block_id")
|
||||
not in non_tool_block_ids
|
||||
for link in links
|
||||
)
|
||||
|
||||
if not has_tools:
|
||||
self.add_error(
|
||||
f"SmartDecisionMakerBlock node '{customized_name}' "
|
||||
f"({node_id}) has no downstream tool blocks connected. "
|
||||
f"Connect at least one block to its 'tools' output so "
|
||||
f"the AI has tools to call."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_mcp_tool_blocks(self, agent: AgentDict) -> bool:
|
||||
"""Validate that MCPToolBlock nodes have required fields.
|
||||
|
||||
@@ -870,6 +978,9 @@ class AgentValidator:
|
||||
logger.info("Validating agent...")
|
||||
self.errors = []
|
||||
|
||||
# Build node lookup once and share across validation methods
|
||||
node_lookup = self._build_node_lookup(agent)
|
||||
|
||||
checks = [
|
||||
(
|
||||
"Block existence",
|
||||
@@ -885,15 +996,15 @@ class AgentValidator:
|
||||
),
|
||||
(
|
||||
"Data type compatibility",
|
||||
self.validate_data_type_compatibility(agent, blocks),
|
||||
self.validate_data_type_compatibility(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Nested sink links",
|
||||
self.validate_nested_sink_links(agent, blocks),
|
||||
self.validate_nested_sink_links(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Source output existence",
|
||||
self.validate_source_output_existence(agent, blocks),
|
||||
self.validate_source_output_existence(agent, blocks, node_lookup),
|
||||
),
|
||||
(
|
||||
"Prompt double curly braces spaces",
|
||||
@@ -913,6 +1024,10 @@ class AgentValidator:
|
||||
"MCP tool blocks",
|
||||
self.validate_mcp_tool_blocks(agent),
|
||||
),
|
||||
(
|
||||
"SmartDecisionMaker blocks",
|
||||
self.validate_smart_decision_maker_blocks(agent, node_lookup),
|
||||
),
|
||||
]
|
||||
|
||||
# Add AgentExecutorBlock detailed validation if library_agents
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.api.features.store.model import StoreAgent, StoreAgentDetails
|
||||
|
||||
from backend.data.db_accessors import library_db, store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
@@ -19,16 +19,12 @@ from .models import (
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .utils import is_creator_slug, is_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchSource = Literal["marketplace", "library"]
|
||||
|
||||
_UUID_PATTERN = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Keywords that should be treated as "list all" rather than a literal search
|
||||
_LIST_ALL_KEYWORDS = frozenset({"all", "*", "everything", "any", ""})
|
||||
|
||||
@@ -39,149 +35,160 @@ async def search_agents(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Search for agents in marketplace or user library.
|
||||
"""Search for agents in marketplace or user library."""
|
||||
if source == "marketplace":
|
||||
return await _search_marketplace(query, session_id)
|
||||
else:
|
||||
return await _search_library(query, session_id, user_id)
|
||||
|
||||
For library searches, keywords like "all", "*", "everything", or an empty
|
||||
query will list all agents without filtering.
|
||||
|
||||
Args:
|
||||
query: Search query string. Special keywords list all library agents.
|
||||
source: "marketplace" or "library"
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (required for library search)
|
||||
|
||||
Returns:
|
||||
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
|
||||
"""
|
||||
# Normalize list-all keywords to empty string for library searches
|
||||
if source == "library" and query.lower().strip() in _LIST_ALL_KEYWORDS:
|
||||
query = ""
|
||||
|
||||
if source == "marketplace" and not query:
|
||||
async def _search_marketplace(query: str, session_id: str | None) -> ToolResponseBase:
|
||||
"""Search marketplace agents, with direct creator/slug lookup fallback."""
|
||||
query = query.strip()
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query", session_id=session_id
|
||||
)
|
||||
|
||||
if source == "library" and not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents: list[AgentInfo] = []
|
||||
try:
|
||||
if source == "marketplace":
|
||||
# Direct lookup if query matches "creator/slug" pattern
|
||||
if is_creator_slug(query):
|
||||
logger.info(f"Query looks like creator/slug, trying direct lookup: {query}")
|
||||
creator, slug = query.split("/", 1)
|
||||
agent_info = await _get_marketplace_agent_by_slug(creator, slug)
|
||||
if agent_info:
|
||||
agents.append(agent_info)
|
||||
|
||||
if not agents:
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db().get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if _is_uuid(query):
|
||||
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
|
||||
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
logger.info(f"Found agent by direct ID lookup: {agent.name}")
|
||||
|
||||
if not agents:
|
||||
search_term = query or None
|
||||
logger.info(
|
||||
f"{'Listing all agents in' if not query else 'Searching'} "
|
||||
f"user library{'' if not query else f' for: {query}'}"
|
||||
)
|
||||
results = await library_db().list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=search_term,
|
||||
page_size=50 if not query else 10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(_library_agent_to_info(agent))
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
agents.append(_marketplace_agent_to_info(agent))
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching {source}: {e}", exc_info=True)
|
||||
logger.error(f"Error searching marketplace: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to search {source}. Please try again.",
|
||||
message="Failed to search marketplace. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
if source == "marketplace":
|
||||
suggestions = [
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
]
|
||||
no_results_msg = (
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"No agents found matching '{query}'. Let the user know they can "
|
||||
"try different keywords or browse the marketplace. Also let them "
|
||||
"know you can create a custom agent for them based on their needs."
|
||||
),
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
message=(
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents. "
|
||||
"Let the user know we can create a custom agent for them based on their needs."
|
||||
),
|
||||
title=f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'",
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def _search_library(
|
||||
query: str, session_id: str | None, user_id: str | None
|
||||
) -> ToolResponseBase:
|
||||
"""Search user's library agents, with direct UUID lookup fallback."""
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
query = query.strip()
|
||||
# Normalize list-all keywords to empty string
|
||||
if query.lower() in _LIST_ALL_KEYWORDS:
|
||||
query = ""
|
||||
|
||||
agents: list[AgentInfo] = []
|
||||
try:
|
||||
if is_uuid(query):
|
||||
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
|
||||
agent = await _get_library_agent_by_id(user_id, query)
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
|
||||
if not agents:
|
||||
logger.info(
|
||||
f"{'Listing all agents in' if not query else 'Searching'} "
|
||||
f"user library{'' if not query else f' for: {query}'}"
|
||||
)
|
||||
elif not query:
|
||||
# User asked to list all but library is empty
|
||||
suggestions = [
|
||||
"Browse the marketplace to find and add agents",
|
||||
"Use find_agent to search the marketplace",
|
||||
]
|
||||
no_results_msg = (
|
||||
"Your library is empty. Let the user know they can browse the "
|
||||
"marketplace to find agents, or you can create a custom agent "
|
||||
"for them based on their needs."
|
||||
results = await library_db().list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=query or None,
|
||||
page_size=50 if not query else 10,
|
||||
)
|
||||
else:
|
||||
suggestions = [
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
]
|
||||
no_results_msg = (
|
||||
for agent in results.agents:
|
||||
agents.append(_library_agent_to_info(agent))
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching library: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search library. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
if not query:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
"Your library is empty. Let the user know they can browse the "
|
||||
"marketplace to find agents, or you can create a custom agent "
|
||||
"for them based on their needs."
|
||||
),
|
||||
suggestions=[
|
||||
"Browse the marketplace to find and add agents",
|
||||
"Use find_agent to search the marketplace",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"No agents matching '{query}' found in your library. Let the "
|
||||
"user know you can create a custom agent for them based on "
|
||||
"their needs."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||
),
|
||||
suggestions=[
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if source == "marketplace":
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
|
||||
)
|
||||
elif not query:
|
||||
if not query:
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library"
|
||||
else:
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"
|
||||
|
||||
message = (
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents. "
|
||||
"Let the user know we can create a custom agent for them based on their needs."
|
||||
if source == "marketplace"
|
||||
else "Found agents in the user's library. You can provide a link to view "
|
||||
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
|
||||
"execution results, or run_agent to execute. Let the user know we can "
|
||||
"create a custom agent for them based on their needs."
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
message=message,
|
||||
message=(
|
||||
"Found agents in the user's library. You can provide a link to view "
|
||||
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
|
||||
"execution results, or run_agent to execute. Let the user know we can "
|
||||
"create a custom agent for them based on their needs."
|
||||
),
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
@@ -189,9 +196,20 @@ async def search_agents(
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(text: str) -> bool:
|
||||
"""Check if text is a valid UUID v4."""
|
||||
return bool(_UUID_PATTERN.match(text.strip()))
|
||||
def _marketplace_agent_to_info(agent: StoreAgent | StoreAgentDetails) -> AgentInfo:
|
||||
"""Convert a marketplace agent (StoreAgent or StoreAgentDetails) to an AgentInfo."""
|
||||
return AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
)
|
||||
|
||||
|
||||
def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
|
||||
@@ -214,6 +232,23 @@ def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
|
||||
)
|
||||
|
||||
|
||||
async def _get_marketplace_agent_by_slug(creator: str, slug: str) -> AgentInfo | None:
|
||||
"""Fetch a marketplace agent by creator/slug identifier."""
|
||||
try:
|
||||
details = await store_db().get_store_agent_details(creator, slug)
|
||||
return _marketplace_agent_to_info(details)
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch marketplace agent {creator}/{slug}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
|
||||
"""Fetch a library agent by ID (library agent ID or graph_id).
|
||||
|
||||
@@ -226,10 +261,9 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
try:
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return _library_agent_to_info(agent)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by graph_id: {agent_id}")
|
||||
pass
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -241,10 +275,9 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
try:
|
||||
agent = await lib_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return _library_agent_to_info(agent)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||
pass
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""Tests for agent search direct lookup functionality."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .models import AgentsFoundResponse, NoResultsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-agent-search"
|
||||
|
||||
|
||||
class TestMarketplaceSlugLookup:
|
||||
"""Tests for creator/slug direct lookup in marketplace search."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_slug_lookup_found(self):
|
||||
"""creator/slug query returns the agent directly."""
|
||||
mock_details = MagicMock()
|
||||
mock_details.creator = "testuser"
|
||||
mock_details.slug = "my-agent"
|
||||
mock_details.agent_name = "My Agent"
|
||||
mock_details.description = "A test agent"
|
||||
mock_details.rating = 4.5
|
||||
mock_details.runs = 100
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_store_agent_details = AsyncMock(return_value=mock_details)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="testuser/my-agent",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].id == "testuser/my-agent"
|
||||
assert response.agents[0].name == "My Agent"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_slug_lookup_not_found_falls_back_to_search(self):
|
||||
"""creator/slug not found falls back to general search."""
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_store_agent_details = AsyncMock(side_effect=NotFoundError(""))
|
||||
|
||||
# Fallback search returns results
|
||||
mock_search_results = MagicMock()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.creator = "other"
|
||||
mock_agent.slug = "similar-agent"
|
||||
mock_agent.agent_name = "Similar Agent"
|
||||
mock_agent.description = "A similar agent"
|
||||
mock_agent.rating = 3.0
|
||||
mock_agent.runs = 50
|
||||
mock_search_results.agents = [mock_agent]
|
||||
|
||||
mock_store.get_store_agents = AsyncMock(return_value=mock_search_results)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="testuser/my-agent",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].id == "other/similar-agent"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_slug_lookup_not_found_no_search_results(self):
|
||||
"""creator/slug not found and search returns nothing."""
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_store_agent_details = AsyncMock(side_effect=NotFoundError(""))
|
||||
mock_search_results = MagicMock()
|
||||
mock_search_results.agents = []
|
||||
mock_store.get_store_agents = AsyncMock(return_value=mock_search_results)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="testuser/nonexistent",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_slug_query_goes_to_search(self):
|
||||
"""Regular keyword query skips slug lookup and goes to search."""
|
||||
mock_store = MagicMock()
|
||||
mock_search_results = MagicMock()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.creator = "creator1"
|
||||
mock_agent.slug = "email-agent"
|
||||
mock_agent.agent_name = "Email Agent"
|
||||
mock_agent.description = "Sends emails"
|
||||
mock_agent.rating = 4.0
|
||||
mock_agent.runs = 200
|
||||
mock_search_results.agents = [mock_agent]
|
||||
mock_store.get_store_agents = AsyncMock(return_value=mock_search_results)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="email",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
# get_store_agent_details should NOT have been called
|
||||
mock_store.get_store_agent_details.assert_not_called()
|
||||
|
||||
|
||||
class TestLibraryUUIDLookup:
|
||||
"""Tests for UUID direct lookup in library search."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_found_by_graph_id(self):
|
||||
"""UUID query matching a graph_id returns the agent directly."""
|
||||
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.id = "lib-agent-id"
|
||||
mock_agent.name = "My Library Agent"
|
||||
mock_agent.description = "A library agent"
|
||||
mock_agent.creator_name = "testuser"
|
||||
mock_agent.status.value = "HEALTHY"
|
||||
mock_agent.can_access_graph = True
|
||||
mock_agent.has_external_trigger = False
|
||||
mock_agent.new_output = False
|
||||
mock_agent.graph_id = agent_id
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.input_schema = {}
|
||||
mock_agent.output_schema = {}
|
||||
|
||||
mock_lib_db = MagicMock()
|
||||
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.library_db",
|
||||
return_value=mock_lib_db,
|
||||
):
|
||||
response = await search_agents(
|
||||
query=agent_id,
|
||||
source="library",
|
||||
session_id="test-session",
|
||||
user_id=_TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].name == "My Library Agent"
|
||||
@@ -164,8 +164,9 @@ class BaseTool:
|
||||
|
||||
"""
|
||||
if self.requires_auth and not user_id:
|
||||
logger.error(
|
||||
f"Attempted tool call for {self.name} but user not authenticated"
|
||||
logger.warning(
|
||||
"Attempted tool call for %s but user not authenticated",
|
||||
self.name,
|
||||
)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
@@ -196,7 +197,7 @@ class BaseTool:
|
||||
output=raw_output,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
logger.warning("Error in %s", self.name, exc_info=True)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
|
||||
@@ -22,6 +22,7 @@ from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
|
||||
from backend.copilot.integration_creds import get_integration_env_vars
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -74,7 +75,10 @@ class BashExecTool(BaseTool):
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
# True because _execute_on_e2b injects user tokens (GH_TOKEN etc.)
|
||||
# when user_id is present. Defense-in-depth: ensures only authenticated
|
||||
# users reach the token injection path.
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
@@ -82,6 +86,14 @@ class BashExecTool(BaseTool):
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Run a bash command on E2B (if available) or in a bubblewrap sandbox.
|
||||
|
||||
Dispatches to :meth:`_execute_on_e2b` when a sandbox is present in the
|
||||
current execution context, otherwise falls back to the local bubblewrap
|
||||
sandbox. Returns a :class:`BashExecResponse` on success or an
|
||||
:class:`ErrorResponse` when the sandbox is unavailable or the command
|
||||
is empty.
|
||||
"""
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
command: str = (kwargs.get("command") or "").strip()
|
||||
@@ -96,7 +108,9 @@ class BashExecTool(BaseTool):
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
return await self._execute_on_e2b(
|
||||
sandbox, command, timeout, session_id, user_id
|
||||
)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
@@ -133,19 +147,42 @@ class BashExecTool(BaseTool):
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
"""Execute *command* on the E2B sandbox via commands.run().
|
||||
|
||||
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
|
||||
for any user with connected accounts. E2B has full internet access, so
|
||||
CLI tools like ``gh`` work without manual authentication.
|
||||
"""
|
||||
envs: dict[str, str] = {
|
||||
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
|
||||
}
|
||||
# Collect injected secret values so we can scrub them from output.
|
||||
secret_values: list[str] = []
|
||||
if user_id is not None:
|
||||
integration_env = await get_integration_env_vars(user_id)
|
||||
secret_values = [v for v in integration_env.values() if v]
|
||||
envs.update(integration_env)
|
||||
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
envs=envs,
|
||||
)
|
||||
stdout = result.stdout or ""
|
||||
stderr = result.stderr or ""
|
||||
# Scrub injected tokens from command output to prevent exfiltration
|
||||
# via `echo $GH_TOKEN`, `env`, `printenv`, etc.
|
||||
for secret in secret_values:
|
||||
stdout = stdout.replace(secret, "[REDACTED]")
|
||||
stderr = stderr.replace(secret, "[REDACTED]")
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
stdout=result.stdout or "",
|
||||
stderr=result.stderr or "",
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
exit_code=result.exit_code,
|
||||
timed_out=False,
|
||||
session_id=session_id,
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests for BashExecTool — E2B path with token injection."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .bash_exec import BashExecTool
|
||||
from .models import BashExecResponse
|
||||
|
||||
_USER = "user-bash-exec-test"
|
||||
|
||||
|
||||
def _make_tool() -> BashExecTool:
|
||||
return BashExecTool()
|
||||
|
||||
|
||||
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.exit_code = exit_code
|
||||
result.stdout = stdout
|
||||
result.stderr = stderr
|
||||
|
||||
sandbox = MagicMock()
|
||||
sandbox.commands.run = AsyncMock(return_value=result)
|
||||
return sandbox
|
||||
|
||||
|
||||
class TestBashExecE2BTokenInjection:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_token_injected_when_user_id_set(self):
|
||||
"""When user_id is provided, integration env vars are merged into sandbox envs."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value=env_vars),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=_USER,
|
||||
)
|
||||
|
||||
mock_get_env.assert_awaited_once_with(_USER)
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
|
||||
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
|
||||
assert isinstance(result, BashExecResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_token_injection_when_user_id_is_none(self):
|
||||
"""When user_id is None, get_integration_env_vars must NOT be called."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
mock_get_env.assert_not_called()
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert "GH_TOKEN" not in call_kwargs["envs"]
|
||||
assert isinstance(result, BashExecResponse)
|
||||
@@ -0,0 +1,196 @@
|
||||
"""Tool for prompting the user to connect a required integration.
|
||||
|
||||
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
|
||||
"authentication required"), it calls this tool to surface the credentials
|
||||
setup card in the chat — the same UI that appears when a GitHub block runs
|
||||
without configured credentials.
|
||||
"""
|
||||
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.providers import SUPPORTED_PROVIDERS, get_provider_auth_types
|
||||
from backend.copilot.tools.models import (
|
||||
ErrorResponse,
|
||||
ResponseType,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
|
||||
class _CredentialEntry(TypedDict):
|
||||
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials.
|
||||
|
||||
Partially overlaps with :class:`~backend.data.model.CredentialsMetaInput`
|
||||
(``id``, ``title``, ``provider``) but carries extra UI-facing fields
|
||||
(``types``, ``scopes``) that the frontend ``SetupRequirementsCard`` needs
|
||||
to render the inline credential setup card.
|
||||
|
||||
Display name is derived from :data:`SUPPORTED_PROVIDERS` at build time
|
||||
rather than stored here — eliminates the old ``provider_name`` field.
|
||||
``types`` replaces the old singular ``type`` field; the frontend already
|
||||
prefers ``types`` and only fell back to ``type`` for compatibility.
|
||||
"""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
# Slug used as the credential key (e.g. "github").
|
||||
provider: str
|
||||
# All supported credential types the user can choose from (e.g. ["api_key", "oauth2"]).
|
||||
# The first element is the default/primary type.
|
||||
types: list[str]
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class ConnectIntegrationTool(BaseTool):
|
||||
"""Surface the credentials setup UI when an integration is not connected."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "connect_integration"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Prompt the user to connect a required integration (e.g. GitHub). "
|
||||
"Call this when an external CLI or API call fails because the user "
|
||||
"has not connected the relevant account. "
|
||||
"The tool surfaces a credentials setup card in the chat so the user "
|
||||
"can authenticate without leaving the page. "
|
||||
"After the user connects the account, retry the operation. "
|
||||
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
|
||||
"automatically injected per-command in bash_exec — no manual export needed. "
|
||||
"In local bubblewrap mode network is isolated so GitHub CLI commands "
|
||||
"will still fail after connecting; inform the user of this limitation."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Integration provider slug, e.g. 'github'. "
|
||||
"Must be one of the supported providers."
|
||||
),
|
||||
"enum": list(SUPPORTED_PROVIDERS.keys()),
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Brief explanation of why the integration is needed, "
|
||||
"shown to the user in the setup card."
|
||||
),
|
||||
"maxLength": 500,
|
||||
},
|
||||
"scopes": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"OAuth scopes to request. Omit to use the provider default. "
|
||||
"Add extra scopes when you need more access — e.g. for GitHub: "
|
||||
"'repo' (clone/push/pull), 'read:org' (org membership), "
|
||||
"'workflow' (GitHub Actions). "
|
||||
"Requesting only the scopes you actually need is best practice."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["provider"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
# Require auth so only authenticated users can trigger the setup card.
|
||||
# The card itself is user-agnostic (no per-user data needed), so
|
||||
# user_id is intentionally unused in _execute.
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Build and return a :class:`SetupRequirementsResponse` for the requested provider.
|
||||
|
||||
Validates the *provider* slug against the known registry, merges any
|
||||
agent-requested OAuth *scopes* with the provider defaults, and constructs
|
||||
the credential setup card payload that the frontend renders as an inline
|
||||
authentication prompt.
|
||||
|
||||
Returns an :class:`ErrorResponse` if *provider* is unknown.
|
||||
"""
|
||||
_ = user_id # setup card is user-agnostic; auth is enforced via requires_auth
|
||||
session_id = session.session_id if session else None
|
||||
provider: str = (kwargs.get("provider") or "").strip().lower()
|
||||
reason: str = (kwargs.get("reason") or "").strip()[
|
||||
:500
|
||||
] # cap LLM-controlled text
|
||||
extra_scopes: list[str] = [
|
||||
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
|
||||
]
|
||||
|
||||
entry = SUPPORTED_PROVIDERS.get(provider)
|
||||
if not entry:
|
||||
supported = ", ".join(f"'{p}'" for p in SUPPORTED_PROVIDERS)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unknown provider '{provider}'. "
|
||||
f"Supported providers: {supported}."
|
||||
),
|
||||
error="unknown_provider",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
display_name: str = entry["name"]
|
||||
supported_types: list[str] = get_provider_auth_types(provider)
|
||||
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
|
||||
default_scopes: list[str] = entry["default_scopes"]
|
||||
seen: set[str] = set()
|
||||
scopes: list[str] = []
|
||||
for s in default_scopes + extra_scopes:
|
||||
if s not in seen:
|
||||
seen.add(s)
|
||||
scopes.append(s)
|
||||
field_key = f"{provider}_credentials"
|
||||
|
||||
message_parts = [
|
||||
f"To continue, please connect your {display_name} account.",
|
||||
]
|
||||
if reason:
|
||||
message_parts.append(reason)
|
||||
|
||||
credential_entry: _CredentialEntry = {
|
||||
"id": field_key,
|
||||
"title": f"{display_name} Credentials",
|
||||
"provider": provider,
|
||||
"types": supported_types,
|
||||
"scopes": scopes,
|
||||
}
|
||||
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
type=ResponseType.SETUP_REQUIREMENTS,
|
||||
message=" ".join(message_parts),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=f"connect_{provider}",
|
||||
agent_name=display_name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_credentials,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [missing_credentials[field_key]],
|
||||
"inputs": [],
|
||||
"execution_modes": [],
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Tests for ConnectIntegrationTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .models import ErrorResponse, SetupRequirementsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-connect-integration"
|
||||
|
||||
|
||||
class TestConnectIntegrationTool:
|
||||
def _make_tool(self) -> ConnectIntegrationTool:
|
||||
return ConnectIntegrationTool()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unknown_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
assert "nonexistent" in result.message
|
||||
assert "github" in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_provider_returns_setup_response(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_name == "GitHub"
|
||||
assert result.setup_info.agent_id == "connect_github"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_has_missing_credentials_in_readiness(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
readiness = result.setup_info.user_readiness
|
||||
assert readiness.has_all_credentials is False
|
||||
assert readiness.ready_to_run is False
|
||||
assert "github_credentials" in readiness.missing_credentials
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_requirements_include_credential_entry(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
creds = result.setup_info.requirements["credentials"]
|
||||
assert len(creds) == 1
|
||||
assert creds[0]["provider"] == "github"
|
||||
assert creds[0]["id"] == "github_credentials"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_reason_appears_in_message(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
reason = "Needed to create a pull request."
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert reason in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_session_id_propagated(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.session_id == session.session_id
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_provider_case_insensitive(self):
|
||||
"""Provider slug is normalised to lowercase before lookup."""
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="GitHub"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
|
||||
def test_tool_name(self):
|
||||
assert ConnectIntegrationTool().name == "connect_integration"
|
||||
|
||||
def test_requires_auth(self):
|
||||
assert ConnectIntegrationTool().requires_auth is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unauthenticated_user_gets_need_login_response(self):
|
||||
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
|
||||
|
||||
This verifies that the requires_auth guard in BaseTool.execute() fires
|
||||
before _execute() is called, so unauthenticated callers cannot probe
|
||||
which integrations are configured.
|
||||
"""
|
||||
import json
|
||||
|
||||
tool = self._make_tool()
|
||||
# Session still needs a user_id string; the None is passed to execute()
|
||||
# to simulate an unauthenticated call.
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool.execute(
|
||||
user_id=None,
|
||||
session=session,
|
||||
tool_call_id="test-call-id",
|
||||
provider="github",
|
||||
)
|
||||
raw = result.output
|
||||
output = json.loads(raw) if isinstance(raw, str) else raw
|
||||
assert output.get("type") == "need_login"
|
||||
assert result.success is False
|
||||
@@ -41,8 +41,7 @@ import contextlib
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Literal
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.sandbox.sandbox_api import SandboxLifecycle
|
||||
from e2b import AsyncSandbox, SandboxLifecycle
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
|
||||
@@ -19,7 +19,8 @@ class FindAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Discover agents from the marketplace based on capabilities and user needs."
|
||||
"Discover agents from the marketplace based on capabilities and "
|
||||
"user needs, or look up a specific agent by its creator/slug ID."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -29,7 +30,7 @@ class FindAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query describing what the user wants to accomplish. Use single keywords for best results.",
|
||||
"description": "Search query describing what the user wants to accomplish, or a creator/slug ID (e.g. 'username/agent-name') for direct lookup. Use single keywords for best results.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
@@ -38,6 +39,7 @@ class FindAgentTool(BaseTool):
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Search marketplace for agents matching the query."""
|
||||
return await search_agents(
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="marketplace",
|
||||
|
||||
@@ -15,6 +15,7 @@ from .models import (
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
)
|
||||
from .utils import is_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,7 +38,8 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||
|
||||
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||
COPILOT_EXCLUDED_BLOCK_IDS = {
|
||||
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
|
||||
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology;
|
||||
# usable in agent graphs (guide hardcodes its ID) but cannot run standalone.
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
}
|
||||
|
||||
@@ -52,7 +54,8 @@ class FindBlockTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for available blocks by name or description. "
|
||||
"Search for available blocks by name or description, or look up a "
|
||||
"specific block by its ID. "
|
||||
"Blocks are reusable components that perform specific tasks like "
|
||||
"sending emails, making API calls, processing text, etc. "
|
||||
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
|
||||
@@ -68,7 +71,8 @@ class FindBlockTool(BaseTool):
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find blocks by name or description. "
|
||||
"Search query to find blocks by name or description, "
|
||||
"or a block ID (UUID) for direct lookup. "
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
@@ -113,11 +117,77 @@ class FindBlockTool(BaseTool):
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
message="Please provide a search query or block ID",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Direct ID lookup if query looks like a UUID
|
||||
if is_uuid(query):
|
||||
block = get_block(query.lower())
|
||||
if block:
|
||||
if block.disabled:
|
||||
return NoResultsResponse(
|
||||
message=f"Block '{block.name}' (ID: {block.id}) is disabled and cannot be used.",
|
||||
suggestions=["Search for an alternative block by name"],
|
||||
session_id=session_id,
|
||||
)
|
||||
if (
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
if block.block_type == BlockType.MCP_TOOL:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' (ID: {block.id}) is not "
|
||||
"runnable through find_block/run_block. Use "
|
||||
"run_mcp_tool instead."
|
||||
),
|
||||
suggestions=[
|
||||
"Use run_mcp_tool to discover and run this MCP tool",
|
||||
"Search for an alternative block by name",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' (ID: {block.id}) is not available "
|
||||
"in CoPilot. It can only be used within agent graphs."
|
||||
),
|
||||
suggestions=[
|
||||
"Search for an alternative block by name",
|
||||
"Use this block in an agent graph instead",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
summary = BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
description=(
|
||||
block.optimized_description or block.description or ""
|
||||
),
|
||||
categories=[c.value for c in block.categories],
|
||||
)
|
||||
if include_schemas:
|
||||
info = block.get_info()
|
||||
summary.input_schema = info.inputSchema
|
||||
summary.output_schema = info.outputSchema
|
||||
summary.static_output = info.staticOutput
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found block '{block.name}' by ID. "
|
||||
"To see inputs/outputs and execute it, use "
|
||||
"run_block with the block's 'id' - providing "
|
||||
"no inputs."
|
||||
),
|
||||
blocks=[summary],
|
||||
count=1,
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Search for blocks using hybrid search
|
||||
results, total = await search().unified_hybrid_search(
|
||||
query=query,
|
||||
|
||||
@@ -499,3 +499,123 @@ class TestFindBlockFiltering:
|
||||
assert response.blocks[0].input_schema == input_schema
|
||||
assert response.blocks[0].output_schema == output_schema
|
||||
assert response.blocks[0].static_output is True
|
||||
|
||||
|
||||
class TestFindBlockDirectLookup:
|
||||
"""Tests for direct UUID lookup in FindBlockTool."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_found(self):
|
||||
"""UUID query returns the block directly without search."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
block = make_mock_block(block_id, "Test Block", BlockType.STANDARD)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert response.count == 1
|
||||
assert response.blocks[0].id == block_id
|
||||
assert response.blocks[0].name == "Test Block"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_not_found_falls_through(self):
|
||||
"""UUID that doesn't match any block falls through to search."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
|
||||
mock_search_db = MagicMock()
|
||||
mock_search_db.unified_hybrid_search = AsyncMock(return_value=([], 0))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
),
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_disabled_block(self):
|
||||
"""UUID matching a disabled block returns NoResultsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
block = make_mock_block(
|
||||
block_id, "Disabled Block", BlockType.STANDARD, disabled=True
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
assert "disabled" in response.message.lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_excluded_block_type(self):
|
||||
"""UUID matching an excluded block type returns NoResultsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
block = make_mock_block(block_id, "Input Block", BlockType.INPUT)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
assert "not available" in response.message.lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_excluded_block_id(self):
|
||||
"""UUID matching an excluded block ID returns NoResultsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=smart_decision_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
assert "not available" in response.message.lower()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Shared utilities for chat tools."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import model as library_model
|
||||
@@ -19,6 +20,26 @@ from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Shared UUID v4 pattern used by multiple tools for direct ID lookups.
|
||||
_UUID_V4_PATTERN = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def is_uuid(text: str) -> bool:
|
||||
"""Check if text is a valid UUID v4."""
|
||||
return bool(_UUID_V4_PATTERN.match(text.strip()))
|
||||
|
||||
|
||||
# Matches "creator/slug" identifiers used in the marketplace
|
||||
_CREATOR_SLUG_PATTERN = re.compile(r"^[\w-]+/[\w-]+$")
|
||||
|
||||
|
||||
def is_creator_slug(text: str) -> bool:
|
||||
"""Check if text matches a 'creator/slug' marketplace identifier."""
|
||||
return bool(_CREATOR_SLUG_PATTERN.match(text.strip()))
|
||||
|
||||
|
||||
async def fetch_graph_from_store_slug(
|
||||
username: str,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -10,7 +11,9 @@ from pydantic import BaseModel
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
get_workspace_manager,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
@@ -24,6 +27,10 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sentinel file_id used when a tool-result file is read directly from the local
|
||||
# host filesystem (rather than from workspace storage).
|
||||
_LOCAL_TOOL_RESULT_FILE_ID = "local"
|
||||
|
||||
|
||||
async def _resolve_write_content(
|
||||
content_text: str | None,
|
||||
@@ -275,6 +282,93 @@ class WorkspaceFileContentResponse(ToolResponseBase):
|
||||
content_base64: str
|
||||
|
||||
|
||||
_MAX_LOCAL_TOOL_RESULT_BYTES = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
|
||||
def _read_local_tool_result(
|
||||
path: str,
|
||||
char_offset: int,
|
||||
char_length: Optional[int],
|
||||
session_id: str,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Read an SDK tool-result file from local disk.
|
||||
|
||||
This is a fallback for when the model mistakenly calls
|
||||
``read_workspace_file`` with an SDK tool-result path that only exists on
|
||||
the host filesystem, not in cloud workspace storage.
|
||||
|
||||
Defence-in-depth: validates *path* via :func:`is_allowed_local_path`
|
||||
regardless of what the caller has already checked.
|
||||
"""
|
||||
# TOCTOU: path validated then opened separately. Acceptable because
|
||||
# the tool-results directory is server-controlled, not user-writable.
|
||||
expanded = os.path.realpath(os.path.expanduser(path))
|
||||
# Defence-in-depth: re-check with resolved path (caller checked raw path).
|
||||
if not is_allowed_local_path(expanded, sdk_cwd or get_sdk_cwd()):
|
||||
return ErrorResponse(
|
||||
message=f"Path not allowed: {os.path.basename(path)}", session_id=session_id
|
||||
)
|
||||
try:
|
||||
# The 10 MB cap (_MAX_LOCAL_TOOL_RESULT_BYTES) bounds memory usage.
|
||||
# Pre-read size check prevents loading files far above the cap;
|
||||
# the remaining TOCTOU gap is acceptable for server-controlled paths.
|
||||
file_size = os.path.getsize(expanded)
|
||||
if file_size > _MAX_LOCAL_TOOL_RESULT_BYTES:
|
||||
return ErrorResponse(
|
||||
message=(f"File too large: {os.path.basename(path)}"),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Detect binary files: try strict UTF-8 first, fall back to
|
||||
# base64-encoding the raw bytes for binary content.
|
||||
with open(expanded, "rb") as fh:
|
||||
raw = fh.read()
|
||||
try:
|
||||
text_content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# Binary file — return raw base64, ignore char_offset/char_length
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
|
||||
name=os.path.basename(path),
|
||||
path=path,
|
||||
mime_type=mimetypes.guess_type(path)[0] or "application/octet-stream",
|
||||
content_base64=base64.b64encode(raw).decode("ascii"),
|
||||
message=(
|
||||
f"Read {file_size:,} bytes (binary) from local tool-result "
|
||||
f"{os.path.basename(path)}"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
end = (
|
||||
char_offset + char_length if char_length is not None else len(text_content)
|
||||
)
|
||||
slice_text = text_content[char_offset:end]
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {os.path.basename(path)}", session_id=session_id
|
||||
)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Error reading file: {type(exc).__name__}", session_id=session_id
|
||||
)
|
||||
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
|
||||
name=os.path.basename(path),
|
||||
path=path,
|
||||
mime_type=mimetypes.guess_type(path)[0] or "text/plain",
|
||||
content_base64=base64.b64encode(slice_text.encode("utf-8")).decode("ascii"),
|
||||
message=(
|
||||
f"Read chars {char_offset}\u2013{char_offset + len(slice_text)} "
|
||||
f"of {len(text_content):,} chars from local tool-result "
|
||||
f"{os.path.basename(path)}"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||
|
||||
@@ -533,6 +627,14 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
# Fallback: if the path is an SDK tool-result on local disk,
|
||||
# read it directly instead of failing. The model sometimes
|
||||
# calls read_workspace_file for these paths by mistake.
|
||||
sdk_cwd = get_sdk_cwd()
|
||||
if path and is_allowed_local_path(path, sdk_cwd):
|
||||
return _read_local_tool_result(
|
||||
path, char_offset, char_length, session_id, sdk_cwd=sdk_cwd
|
||||
)
|
||||
return resolved
|
||||
target_file_id, file_info = resolved
|
||||
|
||||
|
||||
@@ -2,18 +2,25 @@
|
||||
|
||||
import base64
|
||||
import os
|
||||
import shutil
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import SDK_PROJECTS_DIR, _current_project_dir
|
||||
from backend.copilot.tools._test_data import make_session, setup_test_data
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
from backend.copilot.tools.workspace_files import (
|
||||
_MAX_LOCAL_TOOL_RESULT_BYTES,
|
||||
DeleteWorkspaceFileTool,
|
||||
ListWorkspaceFilesTool,
|
||||
ReadWorkspaceFileTool,
|
||||
WorkspaceDeleteResponse,
|
||||
WorkspaceFileContentResponse,
|
||||
WorkspaceFileListResponse,
|
||||
WorkspaceWriteResponse,
|
||||
WriteWorkspaceFileTool,
|
||||
_read_local_tool_result,
|
||||
_resolve_write_content,
|
||||
_validate_ephemeral_path,
|
||||
)
|
||||
@@ -325,3 +332,294 @@ async def test_write_workspace_file_source_path(setup_test_data):
|
||||
await delete_tool._execute(
|
||||
user_id=user.id, session=session, file_id=write_resp.file_id
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_local_tool_result — local disk fallback for SDK tool-result files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
|
||||
|
||||
class TestReadLocalToolResult:
|
||||
"""Tests for _read_local_tool_result (local disk fallback)."""
|
||||
|
||||
def _make_tool_result(self, encoded: str, filename: str, content: bytes) -> str:
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
|
||||
os.makedirs(tool_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_dir, filename)
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(content)
|
||||
return filepath
|
||||
|
||||
def _cleanup(self, encoded: str) -> None:
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
|
||||
def test_read_text_file(self):
|
||||
"""Read a UTF-8 text tool-result file."""
|
||||
encoded = "-tmp-copilot-local-read-text"
|
||||
path = self._make_tool_result(encoded, "output.txt", b"hello world")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "hello world"
|
||||
assert "text/plain" in result.mime_type
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_read_text_with_offset(self):
|
||||
"""Read a slice of a text file using char_offset and char_length."""
|
||||
encoded = "-tmp-copilot-local-read-offset"
|
||||
path = self._make_tool_result(encoded, "data.txt", b"ABCDEFGHIJ")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 3, 4, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "DEFG"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_read_binary_file(self):
|
||||
"""Binary files are returned as raw base64."""
|
||||
encoded = "-tmp-copilot-local-read-binary"
|
||||
binary_data = bytes(range(256))
|
||||
path = self._make_tool_result(encoded, "image.png", binary_data)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64)
|
||||
assert decoded == binary_data
|
||||
assert "binary" in result.message
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_disallowed_path_rejected(self):
|
||||
"""Paths not under allowed directories are rejected."""
|
||||
result = _read_local_tool_result("/etc/passwd", 0, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "not allowed" in result.message.lower()
|
||||
|
||||
def test_file_not_found(self):
|
||||
"""Missing files return an error."""
|
||||
encoded = "-tmp-copilot-local-read-missing"
|
||||
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
|
||||
os.makedirs(tool_dir, exist_ok=True)
|
||||
path = os.path.join(tool_dir, "nope.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "not found" in result.message.lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_file_too_large(self, monkeypatch):
|
||||
"""Files exceeding the size limit are rejected."""
|
||||
encoded = "-tmp-copilot-local-read-large"
|
||||
# Create a small file but fake os.path.getsize to return a huge value
|
||||
path = self._make_tool_result(encoded, "big.txt", b"small")
|
||||
token = _current_project_dir.set(encoded)
|
||||
monkeypatch.setattr(
|
||||
"os.path.getsize", lambda _: _MAX_LOCAL_TOOL_RESULT_BYTES + 1
|
||||
)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "too large" in result.message.lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_offset_beyond_file_length(self):
|
||||
"""Offset past end-of-file returns empty content."""
|
||||
encoded = "-tmp-copilot-local-read-past-eof"
|
||||
path = self._make_tool_result(encoded, "short.txt", b"abc")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 999, 10, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == ""
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_zero_length_read(self):
|
||||
"""Requesting zero characters returns empty content."""
|
||||
encoded = "-tmp-copilot-local-read-zero-len"
|
||||
path = self._make_tool_result(encoded, "data.txt", b"ABCDEF")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 2, 0, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == ""
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_mime_type_from_json_extension(self):
|
||||
"""JSON files get application/json MIME type, not hardcoded text/plain."""
|
||||
encoded = "-tmp-copilot-local-read-json"
|
||||
path = self._make_tool_result(encoded, "result.json", b'{"key": "value"}')
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
assert result.mime_type == "application/json"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_mime_type_from_png_extension(self):
|
||||
"""Binary .png files get image/png MIME type via mimetypes."""
|
||||
encoded = "-tmp-copilot-local-read-png-mime"
|
||||
binary_data = bytes(range(256))
|
||||
path = self._make_tool_result(encoded, "chart.png", binary_data)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
assert result.mime_type == "image/png"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_explicit_sdk_cwd_parameter(self):
|
||||
"""The sdk_cwd parameter overrides get_sdk_cwd() for path validation."""
|
||||
encoded = "-tmp-copilot-local-read-sdkcwd"
|
||||
path = self._make_tool_result(encoded, "out.txt", b"content")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
# Pass sdk_cwd explicitly — should still succeed because the path
|
||||
# is under SDK_PROJECTS_DIR which is always allowed.
|
||||
result = _read_local_tool_result(
|
||||
path, 0, None, "s1", sdk_cwd="/tmp/copilot-test"
|
||||
)
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "content"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_offset_with_no_length_reads_to_end(self):
|
||||
"""When char_length is None, read from offset to end of file."""
|
||||
encoded = "-tmp-copilot-local-read-offset-noLen"
|
||||
path = self._make_tool_result(encoded, "data.txt", b"0123456789")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 5, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "56789"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReadWorkspaceFileTool fallback to _read_local_tool_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_read_workspace_file_falls_back_to_local_tool_result(setup_test_data):
|
||||
"""When _resolve_file returns ErrorResponse for an allowed local path,
|
||||
ReadWorkspaceFileTool should fall back to _read_local_tool_result."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
|
||||
# Create a real tool-result file on disk so the fallback can read it.
|
||||
encoded = "-tmp-copilot-fallback-test"
|
||||
conv_uuid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results")
|
||||
os.makedirs(tool_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_dir, "result.txt")
|
||||
with open(filepath, "w") as f:
|
||||
f.write("fallback content")
|
||||
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
# Mock _resolve_file to return an ErrorResponse (simulating "file not
|
||||
# found in workspace") so the fallback branch is exercised.
|
||||
mock_resolve = AsyncMock(
|
||||
return_value=ErrorResponse(
|
||||
message="File not found at path: result.txt",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
)
|
||||
with patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve):
|
||||
read_tool = ReadWorkspaceFileTool()
|
||||
result = await read_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
path=filepath,
|
||||
)
|
||||
|
||||
# Should have fallen back to _read_local_tool_result and succeeded.
|
||||
assert isinstance(result, WorkspaceFileContentResponse), (
|
||||
f"Expected fallback to local read, got {type(result).__name__}: "
|
||||
f"{getattr(result, 'message', '')}"
|
||||
)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "fallback content"
|
||||
mock_resolve.assert_awaited_once()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_read_workspace_file_no_fallback_when_resolve_succeeds(setup_test_data):
|
||||
"""When _resolve_file succeeds, the local-disk fallback must NOT be invoked."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
|
||||
fake_file_id = "fake-file-id-001"
|
||||
fake_content = b"workspace content"
|
||||
|
||||
# Build a minimal file_info stub that the tool's happy-path needs.
|
||||
class _FakeFileInfo:
|
||||
id = fake_file_id
|
||||
name = "result.json"
|
||||
path = "/result.json"
|
||||
mime_type = "text/plain"
|
||||
size_bytes = len(fake_content)
|
||||
|
||||
mock_resolve = AsyncMock(return_value=(fake_file_id, _FakeFileInfo()))
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file_by_id = AsyncMock(return_value=fake_content)
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve),
|
||||
patch(
|
||||
"backend.copilot.tools.workspace_files.get_workspace_manager",
|
||||
AsyncMock(return_value=mock_manager),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.workspace_files._read_local_tool_result"
|
||||
) as patched_local,
|
||||
):
|
||||
read_tool = ReadWorkspaceFileTool()
|
||||
result = await read_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
file_id=fake_file_id,
|
||||
)
|
||||
|
||||
# Fallback must not have been called.
|
||||
patched_local.assert_not_called()
|
||||
# Normal workspace path must have produced a content response.
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
assert base64.b64decode(result.content_base64) == fake_content
|
||||
|
||||
@@ -423,7 +423,7 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
BlockCost(
|
||||
cost_amount=10,
|
||||
cost_filter={
|
||||
"model": FluxKontextModelName.PRO.api_name,
|
||||
"model": FluxKontextModelName.FLUX_KONTEXT_PRO,
|
||||
"credentials": {
|
||||
"id": replicate_credentials.id,
|
||||
"provider": replicate_credentials.provider,
|
||||
@@ -434,7 +434,29 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
BlockCost(
|
||||
cost_amount=20,
|
||||
cost_filter={
|
||||
"model": FluxKontextModelName.MAX.api_name,
|
||||
"model": FluxKontextModelName.FLUX_KONTEXT_MAX,
|
||||
"credentials": {
|
||||
"id": replicate_credentials.id,
|
||||
"provider": replicate_credentials.provider,
|
||||
"type": replicate_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=14, # Nano Banana Pro
|
||||
cost_filter={
|
||||
"model": FluxKontextModelName.NANO_BANANA_PRO,
|
||||
"credentials": {
|
||||
"id": replicate_credentials.id,
|
||||
"provider": replicate_credentials.provider,
|
||||
"type": replicate_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=14, # Nano Banana 2
|
||||
cost_filter={
|
||||
"model": FluxKontextModelName.NANO_BANANA_2,
|
||||
"credentials": {
|
||||
"id": replicate_credentials.id,
|
||||
"provider": replicate_credentials.provider,
|
||||
@@ -632,6 +654,17 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
},
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=14, # Nano Banana 2: same pricing tier as Pro
|
||||
cost_filter={
|
||||
"model": ImageGenModel.NANO_BANANA_2,
|
||||
"credentials": {
|
||||
"id": replicate_credentials.id,
|
||||
"provider": replicate_credentials.provider,
|
||||
"type": replicate_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
AIImageCustomizerBlock: [
|
||||
BlockCost(
|
||||
@@ -656,6 +689,17 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
},
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=14, # Nano Banana 2: same pricing tier as Pro
|
||||
cost_filter={
|
||||
"model": GeminiImageModel.NANO_BANANA_2,
|
||||
"credentials": {
|
||||
"id": replicate_credentials.id,
|
||||
"provider": replicate_credentials.provider,
|
||||
"type": replicate_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
VideoNarrationBlock: [
|
||||
BlockCost(
|
||||
|
||||
@@ -30,6 +30,7 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
|
||||
MAX_NODE_EXECUTIONS_FETCH = 1000
|
||||
MAX_LIBRARY_AGENT_EXECUTIONS_FETCH = 10
|
||||
MAX_LIBRARY_AGENTS_LAST_EXECUTED_FETCH = 1000
|
||||
|
||||
# Default limits for potentially large result sets
|
||||
MAX_CREDIT_REFUND_REQUESTS_FETCH = 100
|
||||
@@ -109,6 +110,8 @@ def library_agent_include(
|
||||
- Listing optimization (no nodes/executions): ~2s for 15 agents vs potential timeouts
|
||||
- Unlimited executions: varies by user (thousands of executions = timeouts)
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
result: prisma.types.LibraryAgentInclude = {
|
||||
"Creator": True, # Always needed for creator info
|
||||
"Folder": True, # Always needed for folder info
|
||||
@@ -126,7 +129,7 @@ def library_agent_include(
|
||||
if include_executions:
|
||||
agent_graph_include["Executions"] = {
|
||||
"where": {"userId": user_id},
|
||||
"order_by": {"createdAt": "desc"},
|
||||
"order_by": {"updatedAt": "desc"},
|
||||
"take": execution_limit,
|
||||
}
|
||||
|
||||
|
||||
@@ -110,6 +110,11 @@ def normalize_email(email: str) -> str:
|
||||
return email.strip().lower()
|
||||
|
||||
|
||||
def is_internal_email(email: str) -> bool:
|
||||
"""Return True for @agpt.co addresses, which always bypass the invite gate."""
|
||||
return normalize_email(email).endswith("@agpt.co")
|
||||
|
||||
|
||||
def _normalize_name(name: Optional[str]) -> Optional[str]:
|
||||
if name is None:
|
||||
return None
|
||||
@@ -210,6 +215,25 @@ async def _apply_tally_understanding(
|
||||
)
|
||||
|
||||
|
||||
async def check_invite_eligibility(email: str) -> bool:
|
||||
"""Check if an email is allowed to sign up based on the invite list.
|
||||
|
||||
Args:
|
||||
email: The email to check (will be normalized internally).
|
||||
|
||||
Returns True if the email has an active (INVITED) invite record.
|
||||
Does NOT check enable_invite_gate — the caller is responsible for that.
|
||||
"""
|
||||
email = normalize_email(email)
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": email}
|
||||
)
|
||||
return (
|
||||
invited_user is not None
|
||||
and invited_user.status == prisma.enums.InvitedUserStatus.INVITED
|
||||
)
|
||||
|
||||
|
||||
async def list_invited_users(
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
@@ -664,7 +688,7 @@ async def get_or_activate_user(user_data: dict) -> User:
|
||||
if existing_user is not None:
|
||||
return existing_user
|
||||
|
||||
if not _settings.config.enable_invite_gate or normalized_email.endswith("@agpt.co"):
|
||||
if not _settings.config.enable_invite_gate or is_internal_email(normalized_email):
|
||||
return await _open_signup_create_user(
|
||||
auth_user_id, normalized_email, metadata_name
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
|
||||
from .invited_user import (
|
||||
InvitedUserRecord,
|
||||
bulk_create_invited_users_from_file,
|
||||
check_invite_eligibility,
|
||||
create_invited_user,
|
||||
get_or_activate_user,
|
||||
retry_invited_user_tally,
|
||||
@@ -247,6 +248,10 @@ async def test_get_or_activate_user_creates_user_from_invite(
|
||||
"backend.data.invited_user._apply_tally_understanding",
|
||||
AsyncMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user._settings.config.enable_invite_gate",
|
||||
True,
|
||||
)
|
||||
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
|
||||
@@ -333,3 +338,72 @@ async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows
|
||||
"SKIPPED",
|
||||
]
|
||||
assert create_invited.await_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_invite_eligibility tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_eligibility_returns_true_for_invited(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited = _invited_user_db_record(status=prisma.enums.InvitedUserStatus.INVITED)
|
||||
repo = Mock()
|
||||
repo.find_unique = AsyncMock(return_value=invited)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=repo,
|
||||
)
|
||||
|
||||
result = await check_invite_eligibility("invited@example.com")
|
||||
assert result is True
|
||||
repo.find_unique.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_eligibility_returns_false_for_no_record(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
repo = Mock()
|
||||
repo.find_unique = AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=repo,
|
||||
)
|
||||
|
||||
result = await check_invite_eligibility("unknown@example.com")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_eligibility_returns_false_for_claimed(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
claimed = _invited_user_db_record(status=prisma.enums.InvitedUserStatus.CLAIMED)
|
||||
repo = Mock()
|
||||
repo.find_unique = AsyncMock(return_value=claimed)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=repo,
|
||||
)
|
||||
|
||||
result = await check_invite_eligibility("claimed@example.com")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_eligibility_returns_false_for_revoked(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _invited_user_db_record(status=prisma.enums.InvitedUserStatus.REVOKED)
|
||||
repo = Mock()
|
||||
repo.find_unique = AsyncMock(return_value=revoked)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=repo,
|
||||
)
|
||||
|
||||
result = await check_invite_eligibility("revoked@example.com")
|
||||
assert result is False
|
||||
|
||||
@@ -25,6 +25,53 @@ logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
_on_creds_changed: Callable[[str, str], None] | None = None
|
||||
|
||||
|
||||
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
|
||||
"""Register a callback invoked after any credential is created/updated/deleted.
|
||||
|
||||
The callback receives ``(user_id, provider)`` and should be idempotent.
|
||||
Only one hook can be registered at a time. Intended to be called once at
|
||||
application startup (e.g. by the copilot module) without creating an
|
||||
import cycle.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If a hook is already registered. Call
|
||||
:func:`unregister_creds_changed_hook` first if replacement is needed.
|
||||
"""
|
||||
global _on_creds_changed
|
||||
if _on_creds_changed is not None:
|
||||
raise RuntimeError(
|
||||
"A creds_changed hook is already registered. "
|
||||
"Call unregister_creds_changed_hook() before registering a new one."
|
||||
)
|
||||
_on_creds_changed = hook
|
||||
|
||||
|
||||
def unregister_creds_changed_hook() -> None:
|
||||
"""Remove the currently registered creds-changed hook (if any).
|
||||
|
||||
Primarily useful in tests to reset global state between test cases.
|
||||
"""
|
||||
global _on_creds_changed
|
||||
_on_creds_changed = None
|
||||
|
||||
|
||||
def _invoke_creds_changed_hook(user_id: str, provider: str) -> None:
|
||||
"""Invoke the registered creds-changed hook (if any)."""
|
||||
if _on_creds_changed is not None:
|
||||
try:
|
||||
_on_creds_changed(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Credential-change hook failed for user=%s provider=%s",
|
||||
user_id,
|
||||
provider,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class IntegrationCredentialsManager:
|
||||
"""
|
||||
Handles the lifecycle of integration credentials.
|
||||
@@ -69,7 +116,10 @@ class IntegrationCredentialsManager:
|
||||
return self._locks
|
||||
|
||||
async def create(self, user_id: str, credentials: Credentials) -> None:
|
||||
return await self.store.add_creds(user_id, credentials)
|
||||
result = await self.store.add_creds(user_id, credentials)
|
||||
# Notify listeners so downstream caches are invalidated immediately.
|
||||
_invoke_creds_changed_hook(user_id, credentials.provider)
|
||||
return result
|
||||
|
||||
async def exists(self, user_id: str, credentials_id: str) -> bool:
|
||||
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
|
||||
@@ -146,8 +196,7 @@ class IntegrationCredentialsManager:
|
||||
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
||||
if oauth_handler.needs_refresh(credentials):
|
||||
logger.debug(
|
||||
f"Refreshing '{credentials.provider}' "
|
||||
f"credentials #{credentials.id}"
|
||||
f"Refreshing '{credentials.provider}' credentials #{credentials.id}"
|
||||
)
|
||||
_lock = None
|
||||
if lock:
|
||||
@@ -156,11 +205,16 @@ class IntegrationCredentialsManager:
|
||||
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
# Notify listeners so the refreshed token is picked up immediately.
|
||||
_invoke_creds_changed_hook(user_id, fresh_credentials.provider)
|
||||
if _lock and (await _lock.locked()) and (await _lock.owned()):
|
||||
try:
|
||||
await _lock.release()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to release OAuth refresh lock: {e}")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to release OAuth refresh lock",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
credentials = fresh_credentials
|
||||
return credentials
|
||||
@@ -168,10 +222,17 @@ class IntegrationCredentialsManager:
|
||||
async def update(self, user_id: str, updated: Credentials) -> None:
|
||||
async with self._locked(user_id, updated.id):
|
||||
await self.store.update_creds(user_id, updated)
|
||||
# Notify listeners so the updated credential is picked up immediately.
|
||||
_invoke_creds_changed_hook(user_id, updated.provider)
|
||||
|
||||
async def delete(self, user_id: str, credentials_id: str) -> None:
|
||||
async with self._locked(user_id, credentials_id):
|
||||
# Read inside the lock to avoid TOCTOU — another coroutine could
|
||||
# delete the same credential between the read and the delete.
|
||||
creds = await self.store.get_creds_by_id(user_id, credentials_id)
|
||||
await self.store.delete_creds_by_id(user_id, credentials_id)
|
||||
if creds:
|
||||
_invoke_creds_changed_hook(user_id, creds.provider)
|
||||
|
||||
# -- Locking utilities -- #
|
||||
|
||||
@@ -195,8 +256,11 @@ class IntegrationCredentialsManager:
|
||||
if (await lock.locked()) and (await lock.owned()):
|
||||
try:
|
||||
await lock.release()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to release credentials lock: {e}")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to release credentials lock",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def release_all_locks(self):
|
||||
"""Call this on process termination to ensure all locks are released"""
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Tests for creds_manager hook system: register, invoke, and CRUD integration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.integrations.creds_manager import (
|
||||
_invoke_creds_changed_hook,
|
||||
register_creds_changed_hook,
|
||||
unregister_creds_changed_hook,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_hook():
|
||||
"""Ensure global hook state is clean before and after every test."""
|
||||
unregister_creds_changed_hook()
|
||||
yield
|
||||
unregister_creds_changed_hook()
|
||||
|
||||
|
||||
class TestRegisterCredsChangedHook:
|
||||
def test_register_and_invoke(self):
|
||||
calls: list[tuple[str, str]] = []
|
||||
register_creds_changed_hook(lambda u, p: calls.append((u, p)))
|
||||
|
||||
_invoke_creds_changed_hook("user-1", "github")
|
||||
assert calls == [("user-1", "github")]
|
||||
|
||||
def test_double_register_raises(self):
|
||||
register_creds_changed_hook(lambda u, p: None)
|
||||
with pytest.raises(RuntimeError, match="already registered"):
|
||||
register_creds_changed_hook(lambda u, p: None)
|
||||
|
||||
def test_unregister_then_reregister(self):
|
||||
register_creds_changed_hook(lambda u, p: None)
|
||||
unregister_creds_changed_hook()
|
||||
# Should not raise after unregister.
|
||||
register_creds_changed_hook(lambda u, p: None)
|
||||
|
||||
|
||||
class TestInvokeCredsChangedHook:
|
||||
def test_noop_when_no_hook_registered(self):
|
||||
# Must not raise even when no hook is registered.
|
||||
_invoke_creds_changed_hook("user-1", "github")
|
||||
|
||||
def test_hook_exception_is_swallowed(self):
|
||||
def bad_hook(user_id: str, provider: str) -> None:
|
||||
raise ValueError("boom")
|
||||
|
||||
register_creds_changed_hook(bad_hook)
|
||||
# Must not propagate the exception.
|
||||
_invoke_creds_changed_hook("user-1", "github")
|
||||
|
||||
def test_hook_receives_correct_args(self):
|
||||
calls: list[tuple[str, str]] = []
|
||||
register_creds_changed_hook(lambda u, p: calls.append((u, p)))
|
||||
|
||||
_invoke_creds_changed_hook("user-a", "github")
|
||||
_invoke_creds_changed_hook("user-b", "slack")
|
||||
|
||||
assert calls == [("user-a", "github"), ("user-b", "slack")]
|
||||
@@ -46,7 +46,7 @@ class EmailSender:
|
||||
|
||||
MAX_EMAIL_CHARS = 5_000_000 # ~5MB buffer
|
||||
|
||||
def send_templated(
|
||||
async def send_templated(
|
||||
self,
|
||||
notification: NotificationType,
|
||||
user_email: str,
|
||||
@@ -71,7 +71,7 @@ class EmailSender:
|
||||
template_data = {"notifications": data} if isinstance(data, list) else data
|
||||
|
||||
try:
|
||||
subject, full_message = self.formatter.format_email(
|
||||
subject, full_message = await self.formatter.format_email(
|
||||
base_template=template.base_template,
|
||||
subject_template=template.subject_template,
|
||||
content_template=template.body_template,
|
||||
|
||||
@@ -378,7 +378,7 @@ class NotificationManager(AppService):
|
||||
continue
|
||||
logger.info(f"{events=}")
|
||||
|
||||
self.email_sender.send_templated(
|
||||
await self.email_sender.send_templated(
|
||||
notification=notification_type,
|
||||
user_email=recipient_email,
|
||||
data=events,
|
||||
@@ -600,7 +600,7 @@ class NotificationManager(AppService):
|
||||
return False
|
||||
logger.debug(f"Processing notification for admin: {event}")
|
||||
recipient_email = settings.config.refund_notification_email
|
||||
self.email_sender.send_templated(event.type, recipient_email, event)
|
||||
await self.email_sender.send_templated(event.type, recipient_email, event)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing notification for admin queue: {e}")
|
||||
@@ -632,7 +632,7 @@ class NotificationManager(AppService):
|
||||
|
||||
unsub_link = generate_unsubscribe_link(event.user_id)
|
||||
|
||||
self.email_sender.send_templated(
|
||||
await self.email_sender.send_templated(
|
||||
notification=event.type,
|
||||
user_email=recipient_email,
|
||||
data=event,
|
||||
@@ -715,12 +715,14 @@ class NotificationManager(AppService):
|
||||
try:
|
||||
# Try to render the email to check its size
|
||||
template = self.email_sender._get_template(event.type)
|
||||
_, test_message = self.email_sender.formatter.format_email(
|
||||
base_template=template.base_template,
|
||||
subject_template=template.subject_template,
|
||||
content_template=template.body_template,
|
||||
data={"notifications": chunk},
|
||||
unsubscribe_link=f"{self.email_sender.formatter.env.globals.get('base_url', '')}/profile/settings",
|
||||
_, test_message = (
|
||||
await self.email_sender.formatter.format_email(
|
||||
base_template=template.base_template,
|
||||
subject_template=template.subject_template,
|
||||
content_template=template.body_template,
|
||||
data={"notifications": chunk},
|
||||
unsubscribe_link=f"{self.email_sender.formatter.env.globals.get('base_url', '')}/profile/settings",
|
||||
)
|
||||
)
|
||||
|
||||
if len(test_message) < MAX_EMAIL_SIZE:
|
||||
@@ -730,7 +732,7 @@ class NotificationManager(AppService):
|
||||
f"(size: {len(test_message):,} chars)"
|
||||
)
|
||||
|
||||
self.email_sender.send_templated(
|
||||
await self.email_sender.send_templated(
|
||||
notification=event.type,
|
||||
user_email=recipient_email,
|
||||
data=chunk,
|
||||
@@ -975,7 +977,7 @@ class NotificationManager(AppService):
|
||||
data=summary_data,
|
||||
)
|
||||
|
||||
self.email_sender.send_templated(
|
||||
await self.email_sender.send_templated(
|
||||
notification=event.type,
|
||||
user_email=recipient_email,
|
||||
data=data,
|
||||
|
||||
@@ -19,6 +19,7 @@ class TestNotificationErrorHandling:
|
||||
with patch("backend.notifications.notifications.AppService.__init__"):
|
||||
manager = NotificationManager()
|
||||
manager.email_sender = MagicMock()
|
||||
manager.email_sender.send_templated = AsyncMock()
|
||||
# Mock the _get_template method used by _process_batch
|
||||
template_mock = Mock()
|
||||
template_mock.base_template = "base"
|
||||
@@ -27,9 +28,10 @@ class TestNotificationErrorHandling:
|
||||
manager.email_sender._get_template = Mock(return_value=template_mock)
|
||||
# Mock the formatter
|
||||
manager.email_sender.formatter = Mock()
|
||||
manager.email_sender.formatter.format_email = Mock(
|
||||
manager.email_sender.formatter.format_email = AsyncMock(
|
||||
return_value=("subject", "body content")
|
||||
)
|
||||
manager.email_sender.send_templated = AsyncMock()
|
||||
manager.email_sender.formatter.env = Mock()
|
||||
manager.email_sender.formatter.env.globals = {
|
||||
"base_url": "http://example.com"
|
||||
@@ -331,7 +333,7 @@ class TestNotificationErrorHandling:
|
||||
return ("subject", "x" * 5_000_000) # Over 4.5MB limit
|
||||
return ("subject", "normal sized content")
|
||||
|
||||
notification_manager.email_sender.formatter.format_email = Mock(
|
||||
notification_manager.email_sender.formatter.format_email = AsyncMock(
|
||||
side_effect=format_side_effect
|
||||
)
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ Provides decorators for caching function results with support for:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import inspect
|
||||
import logging
|
||||
import pickle
|
||||
@@ -32,6 +34,9 @@ T = TypeVar("T")
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
# Length of the HMAC-SHA256 signature prefix on cached values.
|
||||
_HMAC_SIG_LEN = 32
|
||||
|
||||
# RECOMMENDED REDIS CONFIGURATION FOR PRODUCTION:
|
||||
# Configure Redis with the following settings for optimal caching performance:
|
||||
# maxmemory-policy allkeys-lru # Evict least recently used keys when memory limit reached
|
||||
@@ -176,35 +181,45 @@ def cached(
|
||||
"""
|
||||
|
||||
def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]:
|
||||
func_name = target_func.__name__
|
||||
cache_storage: dict[tuple, CachedValue] = {}
|
||||
_event_loop_locks: dict[Any, asyncio.Lock] = {}
|
||||
|
||||
def _get_from_redis(redis_key: str) -> Any | None:
|
||||
"""Get value from Redis, optionally refreshing TTL."""
|
||||
"""Get value from Redis, optionally refreshing TTL.
|
||||
|
||||
Values are expected to carry an HMAC-SHA256 prefix for integrity
|
||||
verification. Unsigned (legacy) or tampered entries are silently
|
||||
discarded and treated as cache misses, so the caller recomputes and
|
||||
re-stores them with a valid signature.
|
||||
"""
|
||||
try:
|
||||
if refresh_ttl_on_get:
|
||||
# Use GETEX to get value and refresh expiry atomically
|
||||
cached_bytes = _get_redis().getex(redis_key, ex=ttl_seconds)
|
||||
else:
|
||||
cached_bytes = _get_redis().get(redis_key)
|
||||
|
||||
if cached_bytes and isinstance(cached_bytes, bytes):
|
||||
return pickle.loads(cached_bytes)
|
||||
payload = _verify_and_strip(cached_bytes)
|
||||
if payload is None:
|
||||
logger.warning(
|
||||
"[SECURITY] Cache HMAC verification failed "
|
||||
f"for {func_name}, discarding entry: "
|
||||
"possible tampering or legacy unsigned value"
|
||||
)
|
||||
return None
|
||||
return pickle.loads(payload)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Redis error during cache check for {target_func.__name__}: {e}"
|
||||
)
|
||||
logger.error(f"Redis error during cache check for {func_name}: {e}")
|
||||
return None
|
||||
|
||||
def _set_to_redis(redis_key: str, value: Any) -> None:
|
||||
"""Set value in Redis with TTL."""
|
||||
"""Set HMAC-signed pickled value in Redis with TTL."""
|
||||
try:
|
||||
pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
_get_redis().setex(redis_key, ttl_seconds, pickled_value)
|
||||
pickled = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
_get_redis().setex(redis_key, ttl_seconds, _sign_payload(pickled))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Redis error storing cache for {target_func.__name__}: {e}"
|
||||
)
|
||||
logger.error(f"Redis error storing cache for {func_name}: {e}")
|
||||
|
||||
def _get_from_memory(key: tuple) -> Any | None:
|
||||
"""Get value from in-memory cache, checking TTL."""
|
||||
@@ -212,7 +227,7 @@ def cached(
|
||||
cached_data = cache_storage[key]
|
||||
if time.time() - cached_data.timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {target_func.__name__} args: {key[0]} kwargs: {key[1]}"
|
||||
f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}"
|
||||
)
|
||||
return cached_data.result
|
||||
return None
|
||||
@@ -244,9 +259,7 @@ def cached(
|
||||
@wraps(target_func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
redis_key = (
|
||||
_make_redis_key(key, target_func.__name__) if shared_cache else ""
|
||||
)
|
||||
redis_key = _make_redis_key(key, func_name) if shared_cache else ""
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
@@ -271,7 +284,7 @@ def cached(
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
logger.debug(f"Cache miss for {func_name}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
@@ -291,9 +304,7 @@ def cached(
|
||||
@wraps(target_func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
redis_key = (
|
||||
_make_redis_key(key, target_func.__name__) if shared_cache else ""
|
||||
)
|
||||
redis_key = _make_redis_key(key, func_name) if shared_cache else ""
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
@@ -318,7 +329,7 @@ def cached(
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
logger.debug(f"Cache miss for {func_name}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
@@ -337,16 +348,10 @@ def cached(
|
||||
if shared_cache:
|
||||
if pattern:
|
||||
# Clear entries matching pattern
|
||||
keys = list(
|
||||
_get_redis().scan_iter(
|
||||
f"cache:{target_func.__name__}:{pattern}"
|
||||
)
|
||||
)
|
||||
keys = list(_get_redis().scan_iter(f"cache:{func_name}:{pattern}"))
|
||||
else:
|
||||
# Clear all cache keys
|
||||
keys = list(
|
||||
_get_redis().scan_iter(f"cache:{target_func.__name__}:*")
|
||||
)
|
||||
keys = list(_get_redis().scan_iter(f"cache:{func_name}:*"))
|
||||
|
||||
if keys:
|
||||
pipeline = _get_redis().pipeline()
|
||||
@@ -364,9 +369,7 @@ def cached(
|
||||
|
||||
def cache_info() -> dict[str, int | None]:
|
||||
if shared_cache:
|
||||
cache_keys = list(
|
||||
_get_redis().scan_iter(f"cache:{target_func.__name__}:*")
|
||||
)
|
||||
cache_keys = list(_get_redis().scan_iter(f"cache:{func_name}:*"))
|
||||
return {
|
||||
"size": len(cache_keys),
|
||||
"maxsize": None, # Redis manages its own size
|
||||
@@ -383,7 +386,7 @@ def cached(
|
||||
"""Delete a specific cache entry. Returns True if entry existed."""
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if shared_cache:
|
||||
redis_key = _make_redis_key(key, target_func.__name__)
|
||||
redis_key = _make_redis_key(key, func_name)
|
||||
deleted_count = cast(int, _get_redis().delete(redis_key))
|
||||
return deleted_count > 0
|
||||
else:
|
||||
@@ -401,6 +404,52 @@ def cached(
|
||||
return decorator
|
||||
|
||||
|
||||
def _sign_payload(data: bytes) -> bytes:
|
||||
"""Return *signature + data* (32-byte HMAC-SHA256 prefix)."""
|
||||
sig = hmac.new(_get_hmac_key(), data, hashlib.sha256).digest()
|
||||
return sig + data
|
||||
|
||||
|
||||
def _verify_and_strip(blob: bytes) -> bytes | None:
|
||||
"""Verify the HMAC prefix and return the payload, or `None`.
|
||||
|
||||
During deployment, the cache may still contain unsigned (legacy) entries.
|
||||
These will fail verification and return `None`, causing callers to treat
|
||||
them as cache misses. The value is then recomputed and stored with a valid
|
||||
HMAC signature. This means the transition is fully automatic: no cache
|
||||
flush is required, and all entries self-heal on next access within their
|
||||
TTL window (max 1 hour for the longest-lived entries).
|
||||
"""
|
||||
if len(blob) <= _HMAC_SIG_LEN:
|
||||
return None
|
||||
sig, data = blob[:_HMAC_SIG_LEN], blob[_HMAC_SIG_LEN:]
|
||||
expected = hmac.new(_get_hmac_key(), data, hashlib.sha256).digest()
|
||||
if hmac.compare_digest(sig, expected):
|
||||
return data
|
||||
return None
|
||||
|
||||
|
||||
@cache
|
||||
def _get_hmac_key() -> bytes:
|
||||
"""Derive a stable HMAC key for signing cached values in Redis.
|
||||
|
||||
Uses `encryption_key` — a backend-only secret that Redis never sees.
|
||||
This ensures that even if an attacker compromises Redis, they cannot forge
|
||||
valid HMAC signatures for cache entries.
|
||||
|
||||
Falls back to a hardcoded default with a loud warning so the decorator
|
||||
never crashes in development/test environments without secrets configured.
|
||||
"""
|
||||
secret = settings.secrets.encryption_key
|
||||
if not secret:
|
||||
logger.warning(
|
||||
"[SECURITY] No encryption_key configured: cache HMAC signing will use a "
|
||||
"weak default key. Set ENCRYPTION_KEY for production deployments."
|
||||
)
|
||||
secret = "autogpt-cache-default-hmac-key"
|
||||
return hashlib.sha256(secret.encode()).digest()
|
||||
|
||||
|
||||
def thread_cached(func):
|
||||
"""
|
||||
Thread-local cache decorator for both sync and async functions.
|
||||
|
||||
@@ -1121,3 +1121,105 @@ class TestSharedCache:
|
||||
# Cleanup
|
||||
shared_perf_function.cache_clear()
|
||||
local_perf_function.cache_clear()
|
||||
|
||||
|
||||
class TestCacheHMAC:
|
||||
"""Tests for HMAC integrity verification on Redis-backed cache."""
|
||||
|
||||
def test_hmac_signed_roundtrip(self):
|
||||
"""Values written to Redis can be read back via HMAC verification."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def hmac_roundtrip_fn(x: int) -> dict:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return {"value": x, "nested": [1, 2, 3]}
|
||||
|
||||
hmac_roundtrip_fn.cache_clear()
|
||||
|
||||
result1 = hmac_roundtrip_fn(42)
|
||||
assert result1 == {"value": 42, "nested": [1, 2, 3]}
|
||||
assert call_count == 1
|
||||
|
||||
# Second call should hit cache (HMAC verification passes)
|
||||
result2 = hmac_roundtrip_fn(42)
|
||||
assert result2 == {"value": 42, "nested": [1, 2, 3]}
|
||||
assert call_count == 1
|
||||
|
||||
hmac_roundtrip_fn.cache_clear()
|
||||
|
||||
def test_tampered_cache_entry_rejected(self):
|
||||
"""A tampered Redis entry is rejected and treated as a cache miss."""
|
||||
from backend.util.cache import _get_redis
|
||||
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def tamper_test_fn(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
tamper_test_fn.cache_clear()
|
||||
|
||||
# Populate the cache
|
||||
result = tamper_test_fn(7)
|
||||
assert result == 14
|
||||
assert call_count == 1
|
||||
|
||||
# Find and tamper with the Redis key
|
||||
redis = _get_redis()
|
||||
keys = list(redis.scan_iter("cache:tamper_test_fn:*"))
|
||||
assert len(keys) >= 1, "Expected at least one cache key"
|
||||
|
||||
for key in keys:
|
||||
raw: bytes | None = redis.get(key) # type: ignore[assignment]
|
||||
assert raw is not None
|
||||
# Flip a byte in the signature portion to simulate tampering
|
||||
tampered = bytes([raw[0] ^ 0xFF]) + raw[1:] # type: ignore[index]
|
||||
redis.set(key, tampered)
|
||||
|
||||
# Next call should detect tampering and recompute
|
||||
result2 = tamper_test_fn(7)
|
||||
assert result2 == 14
|
||||
assert call_count == 2 # Had to recompute
|
||||
|
||||
tamper_test_fn.cache_clear()
|
||||
|
||||
def test_unsigned_legacy_entry_rejected(self):
|
||||
"""A raw pickled value (no HMAC prefix) is rejected as a cache miss."""
|
||||
import pickle as _pickle
|
||||
|
||||
from backend.util.cache import _get_redis
|
||||
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def legacy_test_fn(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x + 100
|
||||
|
||||
legacy_test_fn.cache_clear()
|
||||
|
||||
# Manually write an unsigned (legacy) pickled value directly to Redis
|
||||
redis = _get_redis()
|
||||
# We need to figure out the cache key format; populate first then overwrite
|
||||
result = legacy_test_fn(5)
|
||||
assert result == 105
|
||||
assert call_count == 1
|
||||
|
||||
keys = list(redis.scan_iter("cache:legacy_test_fn:*"))
|
||||
assert len(keys) >= 1
|
||||
|
||||
# Overwrite with raw unsigned pickle (simulating a legacy entry)
|
||||
for key in keys:
|
||||
redis.set(key, _pickle.dumps(999))
|
||||
|
||||
# Next call should reject the unsigned value and recompute
|
||||
result2 = legacy_test_fn(5)
|
||||
assert result2 == 105
|
||||
assert call_count == 2
|
||||
|
||||
legacy_test_fn.cache_clear()
|
||||
|
||||
@@ -71,11 +71,15 @@ def sanitize_filename(filename: str) -> str:
|
||||
|
||||
# Truncate if too long
|
||||
if len(sanitized) > MAX_FILENAME_LENGTH:
|
||||
# Keep the extension if possible
|
||||
# Keep the extension if possible, but only if it's reasonable length
|
||||
if "." in sanitized:
|
||||
name, ext = sanitized.rsplit(".", 1)
|
||||
max_name_length = MAX_FILENAME_LENGTH - len(ext) - 1
|
||||
sanitized = name[:max_name_length] + "." + ext
|
||||
# If extension is too long, it's likely not a file extension but just text
|
||||
if len(ext) <= 20:
|
||||
max_name_length = MAX_FILENAME_LENGTH - len(ext) - 1
|
||||
sanitized = name[:max_name_length] + "." + ext
|
||||
else:
|
||||
sanitized = sanitized[:MAX_FILENAME_LENGTH]
|
||||
else:
|
||||
sanitized = sanitized[:MAX_FILENAME_LENGTH]
|
||||
|
||||
@@ -129,7 +133,7 @@ async def store_media_file(
|
||||
|
||||
Return format options:
|
||||
- "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||
- "for_external_api": Returns data URI (base64) - use when sending to external APIs
|
||||
- "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
||||
- "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
|
||||
|
||||
:param file: Data URI, URL, workspace://, or local (relative) path.
|
||||
|
||||
@@ -70,6 +70,10 @@ def _msg_tokens(msg: dict, enc) -> int:
|
||||
# Count tool result tokens
|
||||
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
# Count text block tokens (standard: "text" key, fallback: "content")
|
||||
text_val = item.get("text") or item.get("content", "")
|
||||
tool_call_tokens += _tok_len(text_val, enc)
|
||||
elif isinstance(item, dict) and "content" in item:
|
||||
# Other content types with content field
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
@@ -145,10 +149,16 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
if len(ids) <= max_tok:
|
||||
return text # nothing to do
|
||||
|
||||
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
|
||||
if max_tok < 1:
|
||||
return ""
|
||||
mid = enc.encode(" … ")
|
||||
if max_tok < 3:
|
||||
return enc.decode(ids[:max_tok])
|
||||
|
||||
# Split the allowance between the two ends:
|
||||
head = max_tok // 2 - 1 # -1 for the ellipsis
|
||||
tail = max_tok - head - 1
|
||||
mid = enc.encode(" … ")
|
||||
return enc.decode(ids[:head] + mid + ids[-tail:])
|
||||
|
||||
|
||||
@@ -545,6 +555,14 @@ async def _summarize_messages_llm(
|
||||
"- Actions taken and key decisions made\n"
|
||||
"- Technical specifics (file names, tool outputs, function signatures)\n"
|
||||
"- Errors encountered and resolutions applied\n\n"
|
||||
"IMPORTANT: Preserve all concrete references verbatim — these are small but "
|
||||
"critical for continuing the conversation:\n"
|
||||
"- File paths and directory paths (e.g. /src/app/page.tsx, ./output/result.csv)\n"
|
||||
"- Image/media file paths from tool outputs\n"
|
||||
"- URLs, API endpoints, and webhook addresses\n"
|
||||
"- Resource IDs, session IDs, and identifiers\n"
|
||||
"- Tool names that were called and their key parameters\n"
|
||||
"- Environment variables, config keys, and credentials names (not values)\n\n"
|
||||
"Include ONLY the sections below that have relevant content "
|
||||
"(skip sections with nothing to report):\n\n"
|
||||
"## 1. Primary Request and Intent\n"
|
||||
@@ -552,7 +570,8 @@ async def _summarize_messages_llm(
|
||||
"## 2. Key Technical Concepts\n"
|
||||
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
|
||||
"## 3. Files and Resources Involved\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers. "
|
||||
"Include exact file paths, image paths from tool outputs, and resource URLs.\n\n"
|
||||
"## 4. Errors and Fixes\n"
|
||||
"Problems encountered, error messages, and their resolutions.\n\n"
|
||||
"## 5. All User Messages\n"
|
||||
@@ -566,7 +585,7 @@ async def _summarize_messages_llm(
|
||||
},
|
||||
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
|
||||
],
|
||||
max_tokens=1500,
|
||||
max_tokens=2000,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
@@ -686,11 +705,15 @@ async def compress_context(
|
||||
msgs = [summary_msg] + recent_msgs
|
||||
|
||||
logger.info(
|
||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
||||
f"summarized {messages_summarized} messages"
|
||||
"Context summarized: %d -> %d tokens, summarized %d messages",
|
||||
original_count,
|
||||
total_tokens(),
|
||||
messages_summarized,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
||||
logger.warning(
|
||||
"Summarization failed, continuing with truncation: %s", e
|
||||
)
|
||||
# Fall through to content truncation
|
||||
|
||||
# ---- STEP 2: Normalize content ----------------------------------------
|
||||
|
||||
@@ -50,10 +50,13 @@ BLOCKED_IP_NETWORKS = [
|
||||
# IPv4 Ranges
|
||||
ipaddress.ip_network("0.0.0.0/8"), # "This" Network
|
||||
ipaddress.ip_network("10.0.0.0/8"), # Private-Use
|
||||
ipaddress.ip_network("100.64.0.0/10"), # Shared Address Space (CGNAT, RFC 6598)
|
||||
ipaddress.ip_network("127.0.0.0/8"), # Loopback
|
||||
ipaddress.ip_network("169.254.0.0/16"), # Link Local
|
||||
ipaddress.ip_network("172.16.0.0/12"), # Private-Use
|
||||
ipaddress.ip_network("192.0.0.0/24"), # IETF Protocol Assignments
|
||||
ipaddress.ip_network("192.168.0.0/16"), # Private-Use
|
||||
ipaddress.ip_network("198.18.0.0/15"), # Benchmarking
|
||||
ipaddress.ip_network("224.0.0.0/4"), # Multicast
|
||||
ipaddress.ip_network("240.0.0.0/4"), # Reserved for Future Use
|
||||
# IPv6 Ranges
|
||||
@@ -71,8 +74,17 @@ HOSTNAME_REGEX = re.compile(r"^[A-Za-z0-9.-]+$") # Basic DNS-safe hostname patt
|
||||
def _is_ip_blocked(ip: str) -> bool:
|
||||
"""
|
||||
Checks if the IP address is in a blocked network.
|
||||
|
||||
IPv4-mapped IPv6 addresses (e.g. ``::ffff:127.0.0.1``) are normalized to
|
||||
their IPv4 equivalent before checking, so the IPv4 blocklist cannot be
|
||||
bypassed by encoding a private IPv4 address as IPv6.
|
||||
"""
|
||||
ip_addr = ipaddress.ip_address(ip)
|
||||
|
||||
# Normalize IPv4-mapped IPv6 → IPv4 so the IPv4 blocklist applies
|
||||
if isinstance(ip_addr, ipaddress.IPv6Address) and ip_addr.ipv4_mapped:
|
||||
ip_addr = ip_addr.ipv4_mapped
|
||||
|
||||
return any(ip_addr in network for network in BLOCKED_IP_NETWORKS)
|
||||
|
||||
|
||||
@@ -216,7 +228,7 @@ async def validate_url_host(
|
||||
return parsed, True, []
|
||||
|
||||
# If not allowlisted, go ahead with host resolution and IP target check
|
||||
return parsed, False, await _resolve_and_check_blocked(ascii_hostname)
|
||||
return parsed, False, await resolve_and_check_blocked(ascii_hostname)
|
||||
|
||||
|
||||
def matches_allowed_host(url: URL, allowed: URL) -> bool:
|
||||
@@ -230,7 +242,7 @@ def matches_allowed_host(url: URL, allowed: URL) -> bool:
|
||||
return url.port == allowed.port
|
||||
|
||||
|
||||
async def _resolve_and_check_blocked(hostname: str) -> list[str]:
|
||||
async def resolve_and_check_blocked(hostname: str) -> list[str]:
|
||||
"""
|
||||
Resolves hostname to IPs and raises ValueError if any resolve to
|
||||
a blocked network. Returns the list of resolved IP addresses.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from backend.util.request import pin_url, validate_url_host
|
||||
from backend.util.request import _is_ip_blocked, pin_url, validate_url_host
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -171,3 +171,76 @@ async def test_large_header_handling():
|
||||
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
# ---------- IPv4-mapped IPv6 bypass tests (GHSA-8qc5-rhmg-r6r6) ----------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip, expected_blocked",
|
||||
[
|
||||
# IPv4-mapped IPv6 encoding of blocked IPv4 ranges
|
||||
("::ffff:127.0.0.1", True), # Loopback
|
||||
("::ffff:10.0.0.1", True), # Private-Use
|
||||
("::ffff:192.168.1.1", True), # Private-Use
|
||||
("::ffff:172.16.0.1", True), # Private-Use
|
||||
("::ffff:169.254.1.1", True), # Link Local
|
||||
("::ffff:100.64.0.1", True), # CGNAT (RFC 6598)
|
||||
# Plain IPv4 (should still be blocked)
|
||||
("127.0.0.1", True),
|
||||
("10.0.0.1", True),
|
||||
("100.64.0.1", True), # CGNAT
|
||||
("192.0.0.1", True), # IETF Protocol Assignments
|
||||
("198.18.0.1", True), # Benchmarking
|
||||
# Public IPs (should NOT be blocked)
|
||||
("8.8.8.8", False),
|
||||
("1.1.1.1", False),
|
||||
("::ffff:8.8.8.8", False), # IPv4-mapped but public
|
||||
# Native IPv6 blocked ranges
|
||||
("::1", True), # Loopback
|
||||
("fe80::1", True), # Link-local
|
||||
("fc00::1", True), # ULA
|
||||
# Public IPv6 (should NOT be blocked)
|
||||
("2607:f8b0:4004:800::200e", False), # Google
|
||||
],
|
||||
)
|
||||
def test_is_ip_blocked(ip: str, expected_blocked: bool):
|
||||
assert (
|
||||
_is_ip_blocked(ip) == expected_blocked
|
||||
), f"Expected _is_ip_blocked({ip!r}) == {expected_blocked}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw_url, resolved_ips, should_raise",
|
||||
[
|
||||
# IPv4-mapped IPv6 loopback — must be blocked
|
||||
("mapped-loopback.example.com", ["::ffff:127.0.0.1"], True),
|
||||
# IPv4-mapped IPv6 private — must be blocked
|
||||
("mapped-private.example.com", ["::ffff:10.0.0.1"], True),
|
||||
# CGNAT range — must be blocked
|
||||
("cgnat.example.com", ["100.64.0.1"], True),
|
||||
# Mixed: one public, one mapped-private — must be blocked (any blocked = reject)
|
||||
("mixed.example.com", ["8.8.8.8", "::ffff:192.168.1.1"], True),
|
||||
# All public — should pass
|
||||
("public.example.com", ["8.8.8.8", "9.9.9.9"], False),
|
||||
],
|
||||
)
|
||||
async def test_ipv4_mapped_ipv6_bypass(
|
||||
monkeypatch,
|
||||
raw_url: str,
|
||||
resolved_ips: list[str],
|
||||
should_raise: bool,
|
||||
):
|
||||
"""Ensures IPv4-mapped IPv6 addresses are checked against IPv4 blocklist."""
|
||||
|
||||
def mock_getaddrinfo(host, port, *args, **kwargs):
|
||||
return [(None, None, None, None, (ip, port)) for ip in resolved_ips]
|
||||
|
||||
monkeypatch.setattr("socket.getaddrinfo", mock_getaddrinfo)
|
||||
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError):
|
||||
await validate_url_host(raw_url)
|
||||
else:
|
||||
url, _, ip_addresses = await validate_url_host(raw_url)
|
||||
assert ip_addresses # Should have resolved IPs
|
||||
|
||||
@@ -122,7 +122,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="If authentication is enabled or not",
|
||||
)
|
||||
enable_invite_gate: bool = Field(
|
||||
default=True,
|
||||
default=False,
|
||||
description="If the invite-only signup gate is enforced",
|
||||
)
|
||||
enable_credit: bool = Field(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
|
||||
@@ -10,6 +11,12 @@ from markupsafe import Markup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Resource limits for template rendering
|
||||
MAX_EXPONENT = 1000 # Max allowed exponent in ** operations
|
||||
MAX_RANGE = 10_000 # Max items from range()
|
||||
MAX_SEQUENCE_REPEAT = 10_000 # Max length from sequence * int
|
||||
TEMPLATE_RENDER_TIMEOUT = 10 # Seconds before render is killed
|
||||
|
||||
|
||||
def format_filter_for_jinja2(value, format_string=None):
|
||||
if format_string:
|
||||
@@ -19,9 +26,14 @@ def format_filter_for_jinja2(value, format_string=None):
|
||||
|
||||
class TextFormatter:
|
||||
def __init__(self, autoescape: bool = True):
|
||||
self.env = SandboxedEnvironment(loader=BaseLoader(), autoescape=autoescape)
|
||||
self.env = _RestrictedEnvironment(
|
||||
loader=BaseLoader(), autoescape=autoescape, enable_async=True
|
||||
)
|
||||
self.env.globals.clear()
|
||||
|
||||
# Replace range with a safe capped version
|
||||
self.env.globals["range"] = _safe_range
|
||||
|
||||
# Instead of clearing all filters, just remove potentially unsafe ones
|
||||
unsafe_filters = ["pprint", "tojson", "urlize", "xmlattr"]
|
||||
for f in unsafe_filters:
|
||||
@@ -101,15 +113,34 @@ class TextFormatter:
|
||||
"img": ["src"],
|
||||
}
|
||||
|
||||
def format_string(self, template_str: str, values=None, **kwargs) -> str:
|
||||
"""Regular template rendering with escaping"""
|
||||
async def format_string(
|
||||
self,
|
||||
template_str: str,
|
||||
values=None,
|
||||
*,
|
||||
timeout: float | None = TEMPLATE_RENDER_TIMEOUT,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Render a Jinja2 template with resource limits.
|
||||
|
||||
Uses Jinja2's native async rendering (``render_async``) with
|
||||
``asyncio.wait_for`` as a defense-in-depth timeout.
|
||||
"""
|
||||
try:
|
||||
template = self.env.from_string(template_str)
|
||||
return template.render(values or {}, **kwargs)
|
||||
coro = template.render_async(values or {}, **kwargs)
|
||||
if timeout is not None:
|
||||
return await asyncio.wait_for(coro, timeout=timeout)
|
||||
return await coro
|
||||
except TimeoutError:
|
||||
raise ValueError(
|
||||
f"Template rendering timed out after {timeout}s "
|
||||
"(expression too complex)"
|
||||
)
|
||||
except TemplateError as e:
|
||||
raise ValueError(e) from e
|
||||
|
||||
def format_email(
|
||||
async def format_email(
|
||||
self,
|
||||
subject_template: str,
|
||||
base_template: str,
|
||||
@@ -121,7 +152,7 @@ class TextFormatter:
|
||||
Special handling for email templates where content needs to be rendered as HTML
|
||||
"""
|
||||
# First render the content template
|
||||
content = self.format_string(content_template, data, **kwargs)
|
||||
content = await self.format_string(content_template, data, **kwargs)
|
||||
|
||||
# Clean the HTML + CSS but don't escape it
|
||||
clean_content = bleach.clean(
|
||||
@@ -136,17 +167,21 @@ class TextFormatter:
|
||||
safe_content = Markup(clean_content)
|
||||
|
||||
# Render subject
|
||||
rendered_subject_template = self.format_string(subject_template, data, **kwargs)
|
||||
rendered_subject_template = await self.format_string(
|
||||
subject_template, data, **kwargs
|
||||
)
|
||||
|
||||
# Create new env just for HTML template
|
||||
html_env = SandboxedEnvironment(loader=BaseLoader(), autoescape=True)
|
||||
# Create restricted env for HTML template (defense-in-depth)
|
||||
html_env = _RestrictedEnvironment(
|
||||
loader=BaseLoader(), autoescape=True, enable_async=True
|
||||
)
|
||||
html_env.filters["safe"] = lambda x: (
|
||||
x if isinstance(x, Markup) else Markup(str(x))
|
||||
)
|
||||
|
||||
# Render base template with the safe content
|
||||
template = html_env.from_string(base_template)
|
||||
rendered_base_template = template.render(
|
||||
rendered_base_template = await template.render_async(
|
||||
data={
|
||||
"message": safe_content,
|
||||
"title": rendered_subject_template,
|
||||
@@ -157,6 +192,66 @@ class TextFormatter:
|
||||
return rendered_subject_template, rendered_base_template
|
||||
|
||||
|
||||
def _safe_range(*args: int) -> range:
|
||||
"""range() replacement that caps the number of items to prevent DoS."""
|
||||
r = range(*args)
|
||||
if len(r) > MAX_RANGE:
|
||||
raise OverflowError(f"range() too large ({len(r)} items, max {MAX_RANGE})")
|
||||
return r
|
||||
|
||||
|
||||
class _RestrictedEnvironment(SandboxedEnvironment):
|
||||
"""SandboxedEnvironment with computational complexity limits.
|
||||
|
||||
Prevents resource-exhaustion attacks such as ``{{ 999999999**999999999 }}``
|
||||
or ``{{ range(999999999) | list }}`` by intercepting dangerous builtins.
|
||||
"""
|
||||
|
||||
# Tell Jinja2 to route these operators through call_binop()
|
||||
intercepted_binops = frozenset(["**", "*"])
|
||||
|
||||
def call(
|
||||
__self, # noqa: N805 – Jinja2 convention
|
||||
__context,
|
||||
__obj,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
# Intercept pow() to cap the exponent
|
||||
if __obj is pow and len(args) >= 2:
|
||||
base, exp = args[0], args[1]
|
||||
if isinstance(exp, (int, float)) and abs(exp) > MAX_EXPONENT:
|
||||
raise OverflowError(f"Exponent too large (max {MAX_EXPONENT})")
|
||||
if isinstance(base, (int, float)) and abs(base) > MAX_EXPONENT:
|
||||
raise OverflowError(
|
||||
f"Base too large for exponentiation (max {MAX_EXPONENT})"
|
||||
)
|
||||
return super().call(__context, __obj, *args, **kwargs)
|
||||
|
||||
def call_binop(self, context, operator, left, right):
|
||||
# Intercept the ** (power) operator
|
||||
if operator == "**":
|
||||
if isinstance(right, (int, float)) and abs(right) > MAX_EXPONENT:
|
||||
raise OverflowError(f"Exponent too large (max {MAX_EXPONENT})")
|
||||
if isinstance(left, (int, float)) and abs(left) > MAX_EXPONENT:
|
||||
raise OverflowError(
|
||||
f"Base too large for exponentiation (max {MAX_EXPONENT})"
|
||||
)
|
||||
# Intercept sequence repetition via * (strings, lists, tuples)
|
||||
if operator == "*":
|
||||
if isinstance(left, (str, list, tuple)) and isinstance(right, int):
|
||||
if len(left) * right > MAX_SEQUENCE_REPEAT:
|
||||
raise OverflowError(
|
||||
f"Sequence repeat too large (max {MAX_SEQUENCE_REPEAT} items)"
|
||||
)
|
||||
if isinstance(right, (str, list, tuple)) and isinstance(left, int):
|
||||
if len(right) * left > MAX_SEQUENCE_REPEAT:
|
||||
raise OverflowError(
|
||||
f"Sequence repeat too large (max {MAX_SEQUENCE_REPEAT} items)"
|
||||
)
|
||||
return super().call_binop(context, operator, left, right)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CamelCase splitting
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -165,6 +260,7 @@ class TextFormatter:
|
||||
# Mirrors the frontend exception list in frontend/src/lib/utils.ts.
|
||||
_CAMELCASE_EXCEPTIONS: dict[str, str] = {
|
||||
"Auto GPT": "AutoGPT",
|
||||
"Auto Pilot": "AutoPilot",
|
||||
"Open AI": "OpenAI",
|
||||
"You Tube": "YouTube",
|
||||
"Git Hub": "GitHub",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user