mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
12 Commits
remove-cla
...
feat/analy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9afd9fa01 | ||
|
|
ddb4f6e9de | ||
|
|
f585d97928 | ||
|
|
7d39234fdd | ||
|
|
6e9d4c4333 | ||
|
|
8aad333a45 | ||
|
|
856f0d980d | ||
|
|
3c3aadd361 | ||
|
|
e87a693fdd | ||
|
|
fe265c10d4 | ||
|
|
5d00a94693 | ||
|
|
6e1605994d |
17
.claude/skills/backend-check/SKILL.md
Normal file
17
.claude/skills/backend-check/SKILL.md
Normal file
@@ -0,0 +1,17 @@
|
||||
---
|
||||
name: backend-check
|
||||
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Backend Check
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
|
||||
2. **Fix** any remaining errors manually, re-run until clean
|
||||
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
|
||||
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`
|
||||
35
.claude/skills/code-style/SKILL.md
Normal file
35
.claude/skills/code-style/SKILL.md
Normal file
@@ -0,0 +1,35 @@
|
||||
---
|
||||
name: code-style
|
||||
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
|
||||
user-invocable: false
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Code Style
|
||||
|
||||
## Imports
|
||||
|
||||
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
|
||||
|
||||
## Typing
|
||||
|
||||
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
|
||||
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
|
||||
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
|
||||
|
||||
## Code Structure
|
||||
|
||||
- **List comprehensions** over manual loop-and-append.
|
||||
- **Early return** — guard clauses first, avoid deep nesting.
|
||||
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
|
||||
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
Before finishing, always ask:
|
||||
- Can any function be split into smaller pieces?
|
||||
- Is there unnecessary nesting that an early return would eliminate?
|
||||
- Can any loop be a comprehension?
|
||||
- Is there a simpler way to express this logic?
|
||||
16
.claude/skills/frontend-check/SKILL.md
Normal file
16
.claude/skills/frontend-check/SKILL.md
Normal file
@@ -0,0 +1,16 @@
|
||||
---
|
||||
name: frontend-check
|
||||
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Frontend Check
|
||||
|
||||
## Steps (in order)
|
||||
|
||||
1. **Format**: `pnpm format` — NEVER run individual formatters
|
||||
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
|
||||
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user
|
||||
29
.claude/skills/new-block/SKILL.md
Normal file
29
.claude/skills/new-block/SKILL.md
Normal file
@@ -0,0 +1,29 @@
|
||||
---
|
||||
name: new-block
|
||||
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# New Block Creation
|
||||
|
||||
Read `docs/platform/block-sdk-guide.md` first for the full guide.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
|
||||
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
|
||||
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
|
||||
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
|
||||
- `async def run` that `yield`s output fields
|
||||
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
|
||||
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
|
||||
5. **Format**: `poetry run format`
|
||||
|
||||
## Rules
|
||||
|
||||
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
|
||||
- Use top-level imports, avoid duck typing
|
||||
- Always use `for_block_output` for block outputs
|
||||
28
.claude/skills/openapi-regen/SKILL.md
Normal file
28
.claude/skills/openapi-regen/SKILL.md
Normal file
@@ -0,0 +1,28 @@
|
||||
---
|
||||
name: openapi-regen
|
||||
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# OpenAPI Spec Regeneration
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID
|
||||
pnpm types && pnpm lint && pnpm format
|
||||
```
|
||||
|
||||
## Rules
|
||||
|
||||
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
|
||||
- Don't manually edit files in `src/app/api/__generated__/`
|
||||
- Generated hooks follow: `use{Method}{Version}{OperationName}`
|
||||
@@ -1,200 +0,0 @@
|
||||
---
|
||||
name: pr-address
|
||||
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
|
||||
user-invocable: true
|
||||
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Address
|
||||
|
||||
## Find the PR
|
||||
|
||||
```bash
|
||||
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Fetch comments (all sources)
|
||||
|
||||
### 1. Inline review threads — GraphQL (primary source of actionable items)
|
||||
|
||||
Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed.
|
||||
|
||||
```bash
|
||||
gh api graphql -f query='
|
||||
{
|
||||
repository(owner: "Significant-Gravitas", name: "AutoGPT") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads(first: 100) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes {
|
||||
id
|
||||
isResolved
|
||||
path
|
||||
comments(last: 1) {
|
||||
nodes { databaseId body author { login } createdAt }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
If `pageInfo.hasNextPage` is true, fetch subsequent pages by adding `after: "<endCursor>"` to `reviewThreads(first: 100, after: "...")` and repeat until `hasNextPage` is false.
|
||||
|
||||
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
|
||||
|
||||
### 2. Top-level reviews — REST (MUST paginate)
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
|
||||
**CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80–170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review which is typically posted after several CI runs and sits well beyond the first page.
|
||||
|
||||
Two things to extract:
|
||||
- **Overall state**: look for `CHANGES_REQUESTED` or `APPROVED` reviews.
|
||||
- **Actionable feedback**: non-empty bodies only. Empty-body reviews are thread-resolution events — they indicate progress but have no feedback to act on.
|
||||
|
||||
**Where each reviewer posts:**
|
||||
- `autogpt-reviewer` — posts detailed structured reviews ("Blockers", "Should Fix", "Nice to Have") as **top-level reviews**. Not present on every PR. Address ALL items.
|
||||
- `sentry[bot]` — posts bug predictions as **inline threads**. Fix real bugs, explain false positives.
|
||||
- `coderabbitai[bot]` — posts summaries as **top-level reviews** AND actionable items as **inline threads**. Address actionable items.
|
||||
- Human reviewers — can post in any source. Address ALL non-empty feedback.
|
||||
|
||||
### 3. PR conversation comments — REST
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
|
||||
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
|
||||
|
||||
## For each unaddressed comment
|
||||
|
||||
Address comments **one at a time**: fix → commit → push → inline reply → next.
|
||||
|
||||
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
|
||||
2. Commit and push the fix
|
||||
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
|
||||
|
||||
| Comment type | How to reply |
|
||||
|---|---|
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
|
||||
## Format and commit
|
||||
|
||||
After fixing, format the changed code:
|
||||
|
||||
- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
|
||||
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`
|
||||
|
||||
If API routes changed, regenerate the frontend client:
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
trap "kill $REST_PID 2>/dev/null" EXIT
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID 2>/dev/null; trap - EXIT
|
||||
```
|
||||
Never manually edit files in `src/app/api/__generated__/`.
|
||||
|
||||
Then commit and **push immediately** — never batch commits without pushing.
|
||||
|
||||
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
## The loop
|
||||
|
||||
```text
|
||||
address comments → format → commit → push
|
||||
→ wait for CI (while addressing new comments) → fix failures → push
|
||||
→ re-check comments after CI settles
|
||||
→ repeat until: all comments addressed AND CI green AND no new comments arriving
|
||||
```
|
||||
|
||||
### Polling for CI + new comments
|
||||
|
||||
After pushing, poll for **both** CI status and new comments in a single loop. Do not use `gh pr checks --watch` — it blocks the tool and prevents reacting to new comments while CI is running.
|
||||
|
||||
> **Note:** `gh pr checks --watch --fail-fast` is tempting but it blocks the entire Bash tool call, meaning the agent cannot check for or address new comments until CI fully completes. Always poll manually instead.
|
||||
|
||||
**Polling loop — repeat every 30 seconds:**
|
||||
|
||||
1. Check CI status:
|
||||
```bash
|
||||
gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,name,link
|
||||
```
|
||||
Parse the results: if every check has `bucket` of `"pass"` or `"skipping"`, CI is green. If any has `"fail"`, CI has failed. Otherwise CI is still pending.
|
||||
|
||||
2. Check for merge conflicts:
|
||||
```bash
|
||||
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json mergeable --jq '.mergeable'
|
||||
```
|
||||
If the result is `"CONFLICTING"`, the PR has a merge conflict — see "Resolving merge conflicts" below. If `"UNKNOWN"`, GitHub is still computing mergeability — wait and re-check next poll.
|
||||
|
||||
3. Check for new/changed comments (all three sources):
|
||||
|
||||
**Inline threads** — re-run the GraphQL query from "Fetch comments". For each unresolved thread, record `{thread_id, last_comment_databaseId}` as your baseline. On each poll, action is needed if:
|
||||
- A new thread `id` appears that wasn't in the baseline (new thread), OR
|
||||
- An existing thread's `last_comment_databaseId` has changed (new reply on existing thread)
|
||||
|
||||
**Conversation comments:**
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
Compare total count and newest `id` against baseline. Filter to non-empty, non-bot, non-author-update messages.
|
||||
|
||||
**Top-level reviews:**
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
Watch for new non-empty reviews (`CHANGES_REQUESTED` or `COMMENTED` with body). Compare total count and newest `id` against baseline.
|
||||
|
||||
4. **React in this precedence order (first match wins):**
|
||||
|
||||
| What happened | Action |
|
||||
|---|---|
|
||||
| Merge conflict detected | See "Resolving merge conflicts" below. |
|
||||
| Mergeability is `UNKNOWN` | GitHub is still computing mergeability. Sleep 30 seconds, then restart polling from the top. |
|
||||
| New comments detected | Address them (fix → commit → push → reply). After pushing, re-fetch all comments to update your baseline, then restart this polling loop from the top (new commits invalidate CI status). |
|
||||
| CI failed (bucket == "fail") | Get failed check links: `gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,link --jq '.[] \| select(.bucket == "fail") \| .link'`. Extract run ID from link (format: `.../actions/runs/<run-id>/job/...`), read logs with `gh run view <run-id> --repo Significant-Gravitas/AutoGPT --log-failed`. Fix → commit → push → restart polling. |
|
||||
| CI green + no new comments | **Do not exit immediately.** Bots (coderabbitai, sentry) often post reviews shortly after CI settles. Continue polling for **2 more cycles (60s)** after CI goes green. Only exit after 2 consecutive green+quiet polls. |
|
||||
| CI pending + no new comments | Sleep 30 seconds, then poll again. |
|
||||
|
||||
**The loop ends when:** CI fully green + all comments addressed + **2 consecutive polls with no new comments after CI settled.**
|
||||
|
||||
### Resolving merge conflicts
|
||||
|
||||
1. Identify the PR's target branch and remote:
|
||||
```bash
|
||||
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json baseRefName --jq '.baseRefName'
|
||||
git remote -v # find the remote pointing to Significant-Gravitas/AutoGPT (typically 'upstream' in forks, 'origin' for direct contributors)
|
||||
```
|
||||
|
||||
2. Pull the latest base branch with a 3-way merge:
|
||||
```bash
|
||||
git pull {base-remote} {base-branch} --no-rebase
|
||||
```
|
||||
|
||||
3. Resolve conflicting files, then verify no conflict markers remain:
|
||||
```bash
|
||||
if grep -R -n -E '^(<<<<<<<|=======|>>>>>>>)' <conflicted-files>; then
|
||||
echo "Unresolved conflict markers found — resolve before proceeding."
|
||||
exit 1
|
||||
fi
|
||||
```
|
||||
|
||||
4. Stage and push:
|
||||
```bash
|
||||
git add <conflicted-files>
|
||||
git commit -m "Resolve merge conflicts with {base-branch}"
|
||||
git push
|
||||
```
|
||||
|
||||
5. Restart the polling loop from the top — new commits reset CI status.
|
||||
31
.claude/skills/pr-create/SKILL.md
Normal file
31
.claude/skills/pr-create/SKILL.md
Normal file
@@ -0,0 +1,31 @@
|
||||
---
|
||||
name: pr-create
|
||||
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Create Pull Request
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
|
||||
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
|
||||
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
|
||||
5. **Fill out PR template** as the body — be thorough in the Changes section
|
||||
6. **Format first** (if relevant changes exist):
|
||||
- Backend: `cd autogpt_platform/backend && poetry run format`
|
||||
- Frontend: `cd autogpt_platform/frontend && pnpm format`
|
||||
- Fix any lint errors, then commit formatting changes before pushing
|
||||
7. **Push**: `git push -u origin HEAD`
|
||||
8. **Create PR**: `gh pr create --base dev`
|
||||
9. **Output** the PR URL
|
||||
|
||||
## Rules
|
||||
|
||||
- Always target `dev` branch
|
||||
- Do NOT run tests — CI will handle that
|
||||
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
@@ -1,74 +1,51 @@
|
||||
---
|
||||
name: pr-review
|
||||
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
|
||||
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
|
||||
user-invocable: true
|
||||
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Review
|
||||
# PR Review Comment Workflow
|
||||
|
||||
## Find the PR
|
||||
## Steps
|
||||
|
||||
```bash
|
||||
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||
gh pr view {N}
|
||||
```
|
||||
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
|
||||
2. **Fetch comments** (all three sources):
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
|
||||
3. **Skip** comments already reacted to by PR author
|
||||
4. **For each unreacted comment**:
|
||||
- Read referenced code, make the fix (or reply if you disagree/need info)
|
||||
- **Inline review comments** (`pulls/{N}/comments`):
|
||||
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
|
||||
- **PR conversation comments** (`issues/{N}/comments`):
|
||||
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- No threaded replies — post a new issue comment if needed
|
||||
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
|
||||
5. **Include autogpt-reviewer bot fixes** too
|
||||
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
|
||||
7. **Commit & push**
|
||||
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
|
||||
9. **Stay productive while CI runs** — don't idle. In priority order:
|
||||
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
|
||||
- Address any remaining comments
|
||||
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
|
||||
10. **If CI fails** — fix, go back to step 6
|
||||
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
|
||||
12. **Done** only when: all comments reacted AND CI is green.
|
||||
|
||||
## Read the diff
|
||||
## CRITICAL: Do Not Stop
|
||||
|
||||
```bash
|
||||
gh pr diff {N}
|
||||
```
|
||||
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
|
||||
|
||||
## Fetch existing review comments
|
||||
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
|
||||
|
||||
Before posting anything, fetch existing inline comments to avoid duplicates:
|
||||
## Rules
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||
```
|
||||
|
||||
## What to check
|
||||
|
||||
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||
|
||||
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||
|
||||
**Code quality:** apply rules from backend/frontend CLAUDE.md files.
|
||||
|
||||
**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.
|
||||
|
||||
**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where symbol is **used** not defined, `AsyncMock` for async.
|
||||
|
||||
## Output format
|
||||
|
||||
Every comment **must** be prefixed with `🤖` and a criticality badge:
|
||||
|
||||
| Tier | Badge | Meaning |
|
||||
|---|---|---|
|
||||
| Blocker | `🔴 **Blocker**` | Must fix before merge |
|
||||
| Should Fix | `🟠 **Should Fix**` | Important improvement |
|
||||
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
|
||||
| Nit | `🔵 **Nit**` | Style / wording |
|
||||
|
||||
Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
|
||||
|
||||
## Post inline comments
|
||||
|
||||
For each finding, post an inline comment on the PR (do not just write a local report):
|
||||
|
||||
```bash
|
||||
# Get the latest commit SHA for the PR
|
||||
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')
|
||||
|
||||
# Post an inline comment on a specific file/line
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
|
||||
-f body="🤖 🔴 **Blocker**: <description>" \
|
||||
-f commit_id="$COMMIT_SHA" \
|
||||
-f path="<file path>" \
|
||||
-F line=<line number>
|
||||
```
|
||||
- One todo per comment
|
||||
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
|
||||
- React to every comment: +1 addressed, -1 disagreed (with explanation)
|
||||
|
||||
45
.claude/skills/worktree-setup/SKILL.md
Normal file
45
.claude/skills/worktree-setup/SKILL.md
Normal file
@@ -0,0 +1,45 @@
|
||||
---
|
||||
name: worktree-setup
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Preferred: Use Branchlet
|
||||
|
||||
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
|
||||
|
||||
```bash
|
||||
npm install -g branchlet # install once
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
branchlet list --json # list all worktrees
|
||||
```
|
||||
|
||||
## Manual Fallback
|
||||
|
||||
If branchlet isn't available:
|
||||
|
||||
1. `git worktree add ../<RepoName><N> <branch-name>`
|
||||
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
|
||||
3. Install deps:
|
||||
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
|
||||
- `cd autogpt_platform/frontend && pnpm install`
|
||||
|
||||
## Running the App
|
||||
|
||||
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
|
||||
|
||||
```bash
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd <worktree>/autogpt_platform/backend && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot Testing Gotcha
|
||||
|
||||
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
@@ -1,85 +0,0 @@
|
||||
---
|
||||
name: worktree
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses next available AutoGPT<N>."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "3.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Create the worktree
|
||||
|
||||
Derive paths from the git toplevel. If a name is provided as argument, use it. Otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
PARENT=$(dirname "$ROOT")
|
||||
|
||||
# From an existing branch
|
||||
git worktree add "$PARENT/<NAME>" <branch-name>
|
||||
|
||||
# From a new branch off dev
|
||||
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
|
||||
```
|
||||
|
||||
## Copy environment files
|
||||
|
||||
Copy `.env` from the root worktree. Falls back to `.env.default` if `.env` doesn't exist.
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
TARGET="$(dirname "$ROOT")/<NAME>"
|
||||
|
||||
for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
|
||||
if [ -f "$ROOT/$envpath/.env" ]; then
|
||||
cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
|
||||
elif [ -f "$ROOT/$envpath/.env.default" ]; then
|
||||
cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Install dependencies
|
||||
|
||||
```bash
|
||||
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
|
||||
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
|
||||
cd "$TARGET/autogpt_platform/frontend" && pnpm install
|
||||
```
|
||||
|
||||
Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).
|
||||
|
||||
## Running the app (optional)
|
||||
|
||||
Backend uses ports: 8001, 8002, 8003, 8005, 8006, 8007, 8008. Free them first if needed:
|
||||
|
||||
```bash
|
||||
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd "$TARGET/autogpt_platform/backend" && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot testing
|
||||
|
||||
SDK mode spawns a Claude subprocess — won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
|
||||
## Cleanup
|
||||
|
||||
```bash
|
||||
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
|
||||
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
```
|
||||
|
||||
## Alternative: Branchlet (optional)
|
||||
|
||||
If [branchlet](https://www.npmjs.com/package/branchlet) is installed:
|
||||
|
||||
```bash
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
```
|
||||
116
.github/workflows/platform-backend-ci.yml
vendored
116
.github/workflows/platform-backend-ci.yml
vendored
@@ -5,14 +5,12 @@ on:
|
||||
branches: [master, dev, ci-test*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
pull_request:
|
||||
branches: [master, dev, release-*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
merge_group:
|
||||
@@ -27,91 +25,10 @@ defaults:
|
||||
working-directory: autogpt_platform/backend
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 10
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py3.12-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run Linters
|
||||
run: poetry run lint --skip-pyright
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
type-check:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- name: Run Pyright
|
||||
run: poetry run pyright --pythonversion ${{ matrix.python-version }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 15
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -179,9 +96,9 @@ jobs:
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
- name: Install Poetry (Unix)
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
@@ -239,22 +156,22 @@ jobs:
|
||||
echo "Waiting for ClamAV daemon to start..."
|
||||
max_attempts=60
|
||||
attempt=0
|
||||
|
||||
|
||||
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
|
||||
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
|
||||
sleep 5
|
||||
attempt=$((attempt+1))
|
||||
done
|
||||
|
||||
|
||||
if [ $attempt -eq $max_attempts ]; then
|
||||
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
|
||||
echo "Checking ClamAV service logs..."
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "ClamAV is ready!"
|
||||
|
||||
|
||||
# Verify ClamAV is responsive
|
||||
echo "Testing ClamAV connection..."
|
||||
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
|
||||
@@ -269,13 +186,18 @@ jobs:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- name: Run pytest
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
run: poetry run lint
|
||||
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
else
|
||||
poetry run pytest -s -vv
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
@@ -287,12 +209,6 @@ jobs:
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
@@ -306,3 +222,9 @@ jobs:
|
||||
# the backend service, docker composes, and examples
|
||||
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
|
||||
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
169
.github/workflows/platform-frontend-ci.yml
vendored
169
.github/workflows/platform-frontend-ci.yml
vendored
@@ -120,6 +120,175 @@ jobs:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
314
.github/workflows/platform-fullstack-ci.yml
vendored
314
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -1,18 +1,14 @@
|
||||
name: AutoGPT Platform - Full-stack CI
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/**"
|
||||
merge_group:
|
||||
|
||||
@@ -28,28 +24,42 @@ defaults:
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies to populate cache
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
check-api-types:
|
||||
name: check API types
|
||||
runs-on: ubuntu-latest
|
||||
types:
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -57,256 +67,70 @@ jobs:
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
# ------------------------ Backend setup ------------------------
|
||||
|
||||
- name: Set up Backend - Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Backend - Install Poetry
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Installing Poetry version ${POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
|
||||
|
||||
- name: Set up Backend - Set up dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Set up Backend - Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Set up Backend - Generate Prisma client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- name: Set up Frontend - Export OpenAPI schema from Backend
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
|
||||
|
||||
# ------------------------ Frontend setup ------------------------
|
||||
|
||||
- name: Set up Frontend - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Frontend - Set up Node
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up Frontend - Install dependencies
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up Frontend - Format OpenAPI schema
|
||||
id: format-schema
|
||||
run: pnpm prettier --write ./src/app/api/openapi.json
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after exporting the API schema."
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "\nIn the backend directory:"
|
||||
echo "1. Run 'poetry run export-api-schema --output ../frontend/src/app/api/openapi.json'"
|
||||
echo "\nIn the frontend directory:"
|
||||
echo "2. Run 'pnpm prettier --write src/app/api/openapi.json'"
|
||||
echo "3. Run 'pnpm generate:api'"
|
||||
echo "4. Run 'pnpm types'"
|
||||
echo "5. Fix any TypeScript errors that may have been introduced"
|
||||
echo "6. Commit and push your changes"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
|
||||
- name: Set up Frontend - Generate API client
|
||||
id: generate-api-client
|
||||
run: pnpm orval --config ./orval.config.ts
|
||||
# Continue with type generation & check even if there are schema changes
|
||||
if: success() || (steps.format-schema.outcome == 'success')
|
||||
|
||||
- name: Check for TypeScript errors
|
||||
- name: Run Typescript checks
|
||||
run: pnpm types
|
||||
if: success() || (steps.generate-api-client.outcome == 'success')
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-fullstack-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: autogpt_platform/frontend/playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: autogpt_platform/frontend/test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
@@ -56,36 +56,13 @@ AutoGPT Platform is a monorepo containing:
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
## Summary
|
||||
- use `backticks` freely here
|
||||
PREOF
|
||||
gh pr create --title "..." --body-file "$PR_BODY" --base dev
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, follow a test-first approach:
|
||||
|
||||
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
|
||||
2. **Implement the fix/feature** — write the minimal code to make the test pass.
|
||||
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
|
||||
|
||||
This ensures every change is covered by a test and that the test actually validates the intended behavior.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
|
||||
@@ -58,56 +58,10 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
@@ -203,16 +157,6 @@ yield "image_url", result_url
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Workspace & Media Files
|
||||
|
||||
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
|
||||
- Working on CoPilot file upload/download features
|
||||
- Building blocks that handle `MediaFileType` inputs/outputs
|
||||
- Modifying `WorkspaceManager` or `store_media_file()`
|
||||
- Debugging file persistence or virus scanning issues
|
||||
|
||||
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
@@ -50,7 +50,7 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# =============================== DB MIGRATOR =============================== #
|
||||
@@ -82,7 +82,7 @@ RUN pip3 install prisma>=0.15.0 --break-system-packages
|
||||
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/migrations ./migrations
|
||||
|
||||
# ============================== BACKEND SERVER ============================== #
|
||||
@@ -121,37 +121,19 @@ RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium.
|
||||
# On amd64: install runtime libs + run `agent-browser install` to download
|
||||
# Chrome for Testing (pinned version, tested with Playwright).
|
||||
# On arm64: install system chromium package — Chrome for Testing has no ARM64
|
||||
# binary. AGENT_BROWSER_EXECUTABLE_PATH is set at runtime by the entrypoint
|
||||
# script (below) to redirect agent-browser to the system binary.
|
||||
ARG TARGETARCH
|
||||
RUN apt-get update \
|
||||
&& if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
apt-get install -y --no-install-recommends chromium fonts-liberation; \
|
||||
else \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1; \
|
||||
fi \
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& ([ "$TARGETARCH" = "arm64" ] || agent-browser install) \
|
||||
&& agent-browser install \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
# On arm64 the system chromium is at /usr/bin/chromium; set
|
||||
# AGENT_BROWSER_EXECUTABLE_PATH so agent-browser's daemon uses it instead of
|
||||
# Chrome for Testing (which has no ARM64 binary). On amd64 the variable is left
|
||||
# unset so agent-browser uses the Chrome for Testing binary it downloaded above.
|
||||
RUN printf '#!/bin/sh\n[ -x /usr/bin/chromium ] && export AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium\nexec "$@"\n' \
|
||||
> /usr/local/bin/entrypoint.sh \
|
||||
&& chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
@@ -173,5 +155,4 @@ RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
|
||||
ENV PORT=8000
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["rest"]
|
||||
|
||||
@@ -4,12 +4,14 @@ from difflib import SequenceMatcher
|
||||
from typing import Any, Sequence, get_args, get_origin
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from prisma.models import mv_suggested_blocks
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks._base import (
|
||||
AnyBlockSchema,
|
||||
@@ -22,7 +24,6 @@ from backend.blocks.llm import LlmModel
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
from .model import (
|
||||
BlockCategoryResponse,
|
||||
@@ -270,7 +271,7 @@ async def _build_cached_search_results(
|
||||
|
||||
# Use hybrid search when query is present, otherwise list all blocks
|
||||
if (include_blocks or include_integrations) and normalized_query:
|
||||
block_results, block_total, integration_total = await _text_search_blocks(
|
||||
block_results, block_total, integration_total = await _hybrid_search_blocks(
|
||||
query=search_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
@@ -382,75 +383,117 @@ def _collect_block_results(
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
async def _text_search_blocks(
|
||||
async def _hybrid_search_blocks(
|
||||
*,
|
||||
query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Search blocks using in-memory text matching over the block registry.
|
||||
Search blocks using hybrid search with builder-specific filtering.
|
||||
|
||||
All blocks are already loaded in memory, so this is fast and reliable
|
||||
regardless of whether OpenAI embeddings are available.
|
||||
Uses unified_hybrid_search for semantic + lexical search, then applies
|
||||
post-filtering for block/integration types and scoring adjustments.
|
||||
|
||||
Scoring:
|
||||
- Base: text relevance via _score_primary_fields, plus BLOCK_SCORE_BOOST
|
||||
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
|
||||
to prioritize blocks over marketplace agents in combined results
|
||||
- +30 for exact name match, +15 for prefix name match
|
||||
- +20 if the block has an LlmModel field and the query matches an LLM model name
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
include_blocks: Whether to include regular blocks
|
||||
include_integrations: Whether to include integration blocks
|
||||
|
||||
Returns:
|
||||
Tuple of (scored_items, block_count, integration_count)
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
|
||||
if not include_blocks and not include_integrations:
|
||||
return results, 0, 0
|
||||
return results, block_count, integration_count
|
||||
|
||||
normalized_query = query.strip().lower()
|
||||
|
||||
all_results, _, _ = _collect_block_results(
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
# Fetch more results to account for post-filtering
|
||||
search_results, _ = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=150,
|
||||
min_score=0.10,
|
||||
)
|
||||
|
||||
# Load all blocks for getting BlockInfo
|
||||
all_blocks = load_all_blocks()
|
||||
|
||||
for item in all_results:
|
||||
block_info = item.item
|
||||
assert isinstance(block_info, BlockInfo)
|
||||
name = split_camelcase(block_info.name).lower()
|
||||
for result in search_results:
|
||||
block_id = result["content_id"]
|
||||
|
||||
# Build rich description including input field descriptions,
|
||||
# matching the searchable text that the embedding pipeline uses
|
||||
desc_parts = [block_info.description or ""]
|
||||
block_cls = all_blocks.get(block_info.id)
|
||||
if block_cls is not None:
|
||||
block: AnyBlockSchema = block_cls()
|
||||
desc_parts += [
|
||||
f"{f}: {info.description}"
|
||||
for f, info in block.input_schema.model_fields.items()
|
||||
if info.description
|
||||
]
|
||||
description = " ".join(desc_parts).lower()
|
||||
# Skip excluded blocks
|
||||
if block_id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
metadata = result.get("metadata", {})
|
||||
hybrid_score = result.get("relevance", 0.0)
|
||||
|
||||
# Get the actual block class
|
||||
if block_id not in all_blocks:
|
||||
continue
|
||||
|
||||
block_cls = all_blocks[block_id]
|
||||
block: AnyBlockSchema = block_cls()
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Check block/integration filter using metadata
|
||||
is_integration = metadata.get("is_integration", False)
|
||||
|
||||
if is_integration and not include_integrations:
|
||||
continue
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
# Get block info
|
||||
block_info = block.get_info()
|
||||
|
||||
# Calculate final score: scale hybrid score and add builder-specific bonuses
|
||||
# Hybrid scores are 0-1, builder scores were 0-200+
|
||||
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
|
||||
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
|
||||
|
||||
# Add LLM model match bonus
|
||||
if block_cls is not None and _matches_llm_model(
|
||||
block_cls().input_schema, normalized_query
|
||||
):
|
||||
score += 20
|
||||
has_llm_field = metadata.get("has_llm_model_field", False)
|
||||
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
|
||||
final_score += 20
|
||||
|
||||
if score >= MIN_SCORE_FOR_FILTERED_RESULTS:
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=item.filter_type,
|
||||
score=score + BLOCK_SCORE_BOOST,
|
||||
sort_key=name,
|
||||
)
|
||||
# Add exact/prefix match bonus for deterministic tie-breaking
|
||||
name = block_info.name.lower()
|
||||
if name == normalized_query:
|
||||
final_score += 30
|
||||
elif name.startswith(normalized_query):
|
||||
final_score += 15
|
||||
|
||||
# Track counts
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
else:
|
||||
block_count += 1
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=final_score,
|
||||
sort_key=name,
|
||||
)
|
||||
)
|
||||
|
||||
block_count = sum(1 for r in results if r.filter_type == "blocks")
|
||||
integration_count = sum(1 for r in results if r.filter_type == "integrations")
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, HTTPException, Query, Response, Security
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -27,12 +27,6 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -59,7 +53,6 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -125,8 +118,6 @@ class SessionDetailResponse(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -136,7 +127,6 @@ class SessionSummaryResponse(BaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
is_processing: bool
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
@@ -195,28 +185,6 @@ async def list_sessions(
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
# Batch-check Redis for active stream status on each session
|
||||
processing_set: set[str] = set()
|
||||
if sessions:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
for session in sessions:
|
||||
pipe.hget(
|
||||
f"{config.session_meta_prefix}{session.session_id}",
|
||||
"status",
|
||||
)
|
||||
statuses = await pipe.execute()
|
||||
processing_set = {
|
||||
session.session_id
|
||||
for session, st in zip(sessions, statuses)
|
||||
if st == "running"
|
||||
}
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch processing status from Redis; defaulting to empty"
|
||||
)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
@@ -224,7 +192,6 @@ async def list_sessions(
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
is_processing=session.session_id in processing_set,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
@@ -236,7 +203,7 @@ async def list_sessions(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
user_id: Annotated[str, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
@@ -355,7 +322,7 @@ async def update_session_title_route(
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
@@ -396,10 +363,6 @@ async def get_session(
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Sum token usage from session
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
@@ -407,25 +370,6 @@ async def get_session(
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/usage",
|
||||
)
|
||||
async def get_copilot_usage(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
"""
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -435,7 +379,7 @@ async def get_copilot_usage(
|
||||
)
|
||||
async def cancel_session_task(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CancelSessionResponse:
|
||||
"""Cancel the active streaming task for a session.
|
||||
|
||||
@@ -480,7 +424,7 @@ async def cancel_session_task(
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
@@ -497,7 +441,7 @@ async def stream_chat_post(
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Authenticated user ID.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
@@ -506,7 +450,9 @@ async def stream_chat_post(
|
||||
import time
|
||||
|
||||
stream_start_time = time.perf_counter()
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||
@@ -524,18 +470,6 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Pre-turn rate limit check (token-based).
|
||||
# check_rate_limit short-circuits internally when both limits are 0.
|
||||
if user_id:
|
||||
try:
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
@@ -770,7 +704,7 @@ async def stream_chat_post(
|
||||
)
|
||||
async def resume_session_stream(
|
||||
session_id: str,
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Resume an active stream for a session.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
|
||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
@@ -250,153 +249,3 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
|
||||
|
||||
# ─── Rate limit → 429 ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
# Ensure the rate-limit branch is entered by setting a non-zero limit.
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
assert "daily" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
resets_at = datetime.now(UTC) + timedelta(days=3)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded("weekly", resets_at),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
detail = response.json()["detail"].lower()
|
||||
assert "weekly" in detail
|
||||
assert "resets in" in detail
|
||||
|
||||
|
||||
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
|
||||
"""The 429 response detail should include the human-readable reset time."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded(
|
||||
"daily", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
detail = response.json()["detail"]
|
||||
assert "2h" in detail
|
||||
assert "Resets in" in detail
|
||||
|
||||
|
||||
# ─── Usage endpoint ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_usage(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=status,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_returns_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""GET /usage returns daily and weekly usage."""
|
||||
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily"]["used"] == 500
|
||||
assert data["weekly"]["used"] == 2000
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_rejects_unauthenticated_request() -> None:
|
||||
"""GET /usage should return 401 when no valid JWT is provided."""
|
||||
unauthenticated_app = fastapi.FastAPI()
|
||||
unauthenticated_app.include_router(chat_routes.router)
|
||||
unauthenticated_client = fastapi.testclient.TestClient(unauthenticated_app)
|
||||
|
||||
response = unauthenticated_client.get("/usage")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -165,6 +165,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -205,9 +206,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
default_factory=list,
|
||||
description="List of recent executions with status, score, and summary",
|
||||
)
|
||||
can_access_graph: bool = pydantic.Field(
|
||||
description="Indicates whether the same user owns the corresponding graph"
|
||||
)
|
||||
can_access_graph: bool
|
||||
is_latest_version: bool
|
||||
is_favorite: bool
|
||||
folder_id: str | None = None
|
||||
@@ -325,6 +324,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
|
||||
@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -66,6 +67,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -129,6 +131,7 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -181,6 +184,7 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
|
||||
@@ -5,26 +5,16 @@ Pluggable system for different content sources (store agents, blocks, docs).
|
||||
Each handler knows how to fetch and process its content type for embedding.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, get_args, get_origin
|
||||
from typing import Any, get_args, get_origin
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -164,28 +154,6 @@ class StoreAgentHandler(ContentHandler):
|
||||
}
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _get_enabled_blocks() -> dict[str, AnyBlockSchema]:
|
||||
"""Return ``{block_id: block_instance}`` for all enabled, instantiable blocks.
|
||||
|
||||
Disabled blocks and blocks that fail to instantiate are silently skipped
|
||||
(with a warning log), so callers never need their own try/except loop.
|
||||
|
||||
Results are cached for the process lifetime via ``lru_cache`` because
|
||||
blocks are registered at import time and never change while running.
|
||||
"""
|
||||
enabled: dict[str, AnyBlockSchema] = {}
|
||||
for block_id, block_cls in get_blocks().items():
|
||||
try:
|
||||
instance = block_cls()
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping block {block_id}: init failed: {e}")
|
||||
continue
|
||||
if not instance.disabled:
|
||||
enabled[block_id] = instance
|
||||
return enabled
|
||||
|
||||
|
||||
class BlockHandler(ContentHandler):
|
||||
"""Handler for block definitions (Python classes)."""
|
||||
|
||||
@@ -195,14 +163,16 @@ class BlockHandler(ContentHandler):
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch blocks without embeddings."""
|
||||
# to_thread keeps the first (heavy) call off the event loop. On
|
||||
# subsequent calls the lru_cache makes this a dict lookup, so the
|
||||
# thread-pool overhead is negligible compared to the DB queries below.
|
||||
enabled = await asyncio.to_thread(_get_enabled_blocks)
|
||||
if not enabled:
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
# Get all available blocks
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Check which ones have embeddings
|
||||
if not all_blocks:
|
||||
return []
|
||||
|
||||
block_ids = list(enabled.keys())
|
||||
block_ids = list(all_blocks.keys())
|
||||
|
||||
# Query for existing embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
@@ -217,42 +187,52 @@ class BlockHandler(ContentHandler):
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_blocks = [
|
||||
(block_id, block_cls)
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if block_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem — disabled filtering already done by
|
||||
# _get_enabled_blocks so batch_size won't be exhausted by disabled blocks.
|
||||
missing = ((bid, b) for bid, b in enabled.items() if bid not in existing_ids)
|
||||
# Convert to ContentItem
|
||||
items = []
|
||||
for block_id, block in itertools.islice(missing, batch_size):
|
||||
for block_id, block_cls in missing_blocks[:batch_size]:
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
if block_instance.disabled:
|
||||
continue
|
||||
|
||||
# Build searchable text from block metadata
|
||||
if not block.name:
|
||||
logger.warning(
|
||||
f"Block {block_id} has no name — using block_id as fallback"
|
||||
)
|
||||
display_name = split_camelcase(block.name) if block.name else ""
|
||||
parts = []
|
||||
if display_name:
|
||||
parts.append(display_name)
|
||||
if block.description:
|
||||
parts.append(block.description)
|
||||
if block.categories:
|
||||
parts.append(" ".join(str(cat.value) for cat in block.categories))
|
||||
if block_instance.name:
|
||||
parts.append(block_instance.name)
|
||||
if block_instance.description:
|
||||
parts.append(block_instance.description)
|
||||
if block_instance.categories:
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
)
|
||||
|
||||
# Add input schema field descriptions
|
||||
block_input_fields = block_instance.input_schema.model_fields
|
||||
parts += [
|
||||
f"{field_name}: {field_info.description}"
|
||||
for field_name, field_info in block.input_schema.model_fields.items()
|
||||
for field_name, field_info in block_input_fields.items()
|
||||
if field_info.description
|
||||
]
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
categories_list = (
|
||||
[cat.value for cat in block.categories] if block.categories else []
|
||||
[cat.value for cat in block_instance.categories]
|
||||
if block_instance.categories
|
||||
else []
|
||||
)
|
||||
|
||||
# Extract provider names from credentials fields
|
||||
credentials_info = block.input_schema.get_credentials_fields_info()
|
||||
credentials_info = (
|
||||
block_instance.input_schema.get_credentials_fields_info()
|
||||
)
|
||||
is_integration = len(credentials_info) > 0
|
||||
provider_names = [
|
||||
provider.value.lower()
|
||||
@@ -263,7 +243,7 @@ class BlockHandler(ContentHandler):
|
||||
# Check if block has LlmModel field in input schema
|
||||
has_llm_model_field = any(
|
||||
_contains_type(field.annotation, LlmModel)
|
||||
for field in block.input_schema.model_fields.values()
|
||||
for field in block_instance.input_schema.model_fields.values()
|
||||
)
|
||||
|
||||
items.append(
|
||||
@@ -272,13 +252,13 @@ class BlockHandler(ContentHandler):
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": display_name or block.name or block_id,
|
||||
"name": block_instance.name,
|
||||
"categories": categories_list,
|
||||
"providers": provider_names,
|
||||
"has_llm_model_field": has_llm_model_field,
|
||||
"is_integration": is_integration,
|
||||
},
|
||||
user_id=None,
|
||||
user_id=None, # Blocks are public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -289,13 +269,22 @@ class BlockHandler(ContentHandler):
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about block embedding coverage."""
|
||||
enabled = await asyncio.to_thread(_get_enabled_blocks)
|
||||
total_blocks = len(enabled)
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Filter out disabled blocks - they're not indexed
|
||||
enabled_block_ids = [
|
||||
block_id
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if not block_cls().disabled
|
||||
]
|
||||
total_blocks = len(enabled_block_ids)
|
||||
|
||||
if total_blocks == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
block_ids = list(enabled.keys())
|
||||
block_ids = enabled_block_ids
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""
|
||||
Tests for content handlers (blocks, store agents, documentation).
|
||||
E2E tests for content handlers (blocks, store agents, documentation).
|
||||
|
||||
Tests the full flow: discovering content → generating embeddings → storing.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -13,103 +15,15 @@ from backend.api.features.store.content_handlers import (
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
_get_enabled_blocks,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_block_cache():
|
||||
"""Clear the lru_cache on _get_enabled_blocks before each test."""
|
||||
_get_enabled_blocks.cache_clear()
|
||||
yield
|
||||
_get_enabled_blocks.cache_clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to build a mock block class that returns a pre-configured instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_block_class(
|
||||
*,
|
||||
name: str = "Block",
|
||||
description: str = "",
|
||||
disabled: bool = False,
|
||||
categories: list[MagicMock] | None = None,
|
||||
fields: dict[str, str] | None = None,
|
||||
raise_on_init: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
cls = MagicMock()
|
||||
if raise_on_init is not None:
|
||||
cls.side_effect = raise_on_init
|
||||
return cls
|
||||
inst = MagicMock()
|
||||
inst.name = name
|
||||
inst.disabled = disabled
|
||||
inst.description = description
|
||||
inst.categories = categories or []
|
||||
field_mocks = {
|
||||
fname: MagicMock(description=fdesc) for fname, fdesc in (fields or {}).items()
|
||||
}
|
||||
inst.input_schema.model_fields = field_mocks
|
||||
inst.input_schema.get_credentials_fields_info.return_value = {}
|
||||
cls.return_value = inst
|
||||
return cls
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_enabled_blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_enabled_blocks_filters_disabled():
|
||||
"""Disabled blocks are excluded."""
|
||||
blocks = {
|
||||
"enabled": _make_block_class(name="E", disabled=False),
|
||||
"disabled": _make_block_class(name="D", disabled=True),
|
||||
}
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
result = _get_enabled_blocks()
|
||||
assert list(result.keys()) == ["enabled"]
|
||||
|
||||
|
||||
def test_get_enabled_blocks_skips_broken():
|
||||
"""Blocks that raise on init are skipped without crashing."""
|
||||
blocks = {
|
||||
"good": _make_block_class(name="Good"),
|
||||
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
|
||||
}
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
result = _get_enabled_blocks()
|
||||
assert list(result.keys()) == ["good"]
|
||||
|
||||
|
||||
def test_get_enabled_blocks_cached():
|
||||
"""_get_enabled_blocks() calls get_blocks() only once across multiple calls."""
|
||||
blocks = {"b1": _make_block_class(name="B1")}
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
) as mock_get_blocks:
|
||||
result1 = _get_enabled_blocks()
|
||||
result2 = _get_enabled_blocks()
|
||||
assert result1 is result2
|
||||
mock_get_blocks.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StoreAgentHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_missing_items(mocker):
|
||||
"""Test StoreAgentHandler fetches approved agents without embeddings."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock database query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "agent-1",
|
||||
@@ -140,7 +54,9 @@ async def test_store_agent_handler_get_stats(mocker):
|
||||
"""Test StoreAgentHandler returns correct stats."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock approved count query
|
||||
mock_approved = [{"count": 50}]
|
||||
# Mock embedded count query
|
||||
mock_embedded = [{"count": 30}]
|
||||
|
||||
with patch(
|
||||
@@ -154,130 +70,74 @@ async def test_store_agent_handler_get_stats(mocker):
|
||||
assert stats["without_embeddings"] == 20
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BlockHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items():
|
||||
async def test_block_handler_get_missing_items(mocker):
|
||||
"""Test BlockHandler discovers blocks without embeddings."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"block-uuid-1": _make_block_class(
|
||||
name="CalculatorBlock",
|
||||
description="Performs calculations",
|
||||
categories=[MagicMock(value="MATH")],
|
||||
fields={"expression": "Math expression to evaluate"},
|
||||
),
|
||||
}
|
||||
# Mock get_blocks to return test blocks
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Calculator Block"
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.disabled = False
|
||||
mock_field = MagicMock()
|
||||
mock_field.description = "Math expression to evaluate"
|
||||
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
mock_existing = []
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
return_value=mock_existing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "block-uuid-1"
|
||||
assert items[0].content_type == ContentType.BLOCK
|
||||
# CamelCase should be split in searchable text and metadata name
|
||||
assert "Calculator Block" in items[0].searchable_text
|
||||
assert "Performs calculations" in items[0].searchable_text
|
||||
assert "MATH" in items[0].searchable_text
|
||||
assert "expression: Math expression" in items[0].searchable_text
|
||||
assert items[0].metadata["name"] == "Calculator Block"
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items_splits_camelcase():
|
||||
"""CamelCase block names are split for better search indexing."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"ai-block": _make_block_class(name="AITextGeneratorBlock"),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert "AI Text Generator Block" in items[0].searchable_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items_batch_size_zero():
|
||||
"""batch_size=0 returns an empty list; the DB is still queried to find missing IDs."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {"b1": _make_block_class(name="B1")}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
) as mock_query:
|
||||
items = await handler.get_missing_items(batch_size=0)
|
||||
assert items == []
|
||||
# DB query is still issued to learn which blocks lack embeddings;
|
||||
# the empty result comes from itertools.islice limiting to 0 items.
|
||||
mock_query.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_disabled_dont_exhaust_batch():
|
||||
"""Disabled blocks don't consume batch budget, so enabled blocks get indexed."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# 5 disabled + 3 enabled, batch_size=2
|
||||
blocks = {
|
||||
**{
|
||||
f"dis-{i}": _make_block_class(name=f"D{i}", disabled=True) for i in range(5)
|
||||
},
|
||||
**{f"en-{i}": _make_block_class(name=f"E{i}") for i in range(3)},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=2)
|
||||
|
||||
assert len(items) == 2
|
||||
assert all(item.content_id.startswith("en-") for item in items)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats():
|
||||
async def test_block_handler_get_stats(mocker):
|
||||
"""Test BlockHandler returns correct stats."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"block-1": _make_block_class(name="B1"),
|
||||
"block-2": _make_block_class(name="B2"),
|
||||
"block-3": _make_block_class(name="B3"),
|
||||
# Mock get_blocks - each block class returns an instance with disabled=False
|
||||
def make_mock_block_class():
|
||||
mock_class = MagicMock()
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.disabled = False
|
||||
mock_class.return_value = mock_instance
|
||||
return mock_class
|
||||
|
||||
mock_blocks = {
|
||||
"block-1": make_mock_block_class(),
|
||||
"block-2": make_mock_block_class(),
|
||||
"block-3": make_mock_block_class(),
|
||||
}
|
||||
|
||||
# Mock embedded count query (2 blocks have embeddings)
|
||||
mock_embedded = [{"count": 2}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
@@ -290,123 +150,21 @@ async def test_block_handler_get_stats():
|
||||
assert stats["without_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats_skips_broken():
|
||||
"""get_stats skips broken blocks instead of crashing."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"good": _make_block_class(name="Good"),
|
||||
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
|
||||
}
|
||||
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 1 # only the good block
|
||||
assert stats["with_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_none_name():
|
||||
"""When block.name is None the fallback display name logic is used."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"none-name-block": _make_block_class(
|
||||
name="placeholder", # will be overridden to None below
|
||||
description="A block with no name",
|
||||
),
|
||||
}
|
||||
# Override the name to None after construction so _make_block_class
|
||||
# doesn't interfere with the mock wiring.
|
||||
blocks["none-name-block"].return_value.name = None
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
# display_name should be "" because block.name is None
|
||||
# searchable_text should still contain the description
|
||||
assert "A block with no name" in items[0].searchable_text
|
||||
# metadata["name"] falls back to block_id when both display_name
|
||||
# and block.name are falsy, ensuring it is always a non-empty string.
|
||||
assert items[0].metadata["name"] == "none-name-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_empty_attributes():
|
||||
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {"block-minimal": _make_block_class(name="Minimal Block")}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"good-block": _make_block_class(name="Good Block", description="Works fine"),
|
||||
"bad-block": _make_block_class(raise_on_init=Exception("Instantiation failed")),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocumentationHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
"""Test DocumentationHandler discovers docs without embeddings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory with test files
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
|
||||
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
|
||||
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
|
||||
|
||||
# Mock _get_docs_root to return temp dir
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
@@ -415,6 +173,7 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
# Check guide.md (content_id format: doc_path::section_index)
|
||||
guide_item = next(
|
||||
(item for item in items if item.content_id == "guide.md::0"), None
|
||||
)
|
||||
@@ -425,6 +184,7 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
assert guide_item.metadata["doc_title"] == "Getting Started"
|
||||
assert guide_item.user_id is None
|
||||
|
||||
# Check api.mdx (content_id format: doc_path::section_index)
|
||||
api_item = next(
|
||||
(item for item in items if item.content_id == "api.mdx::0"), None
|
||||
)
|
||||
@@ -437,12 +197,14 @@ async def test_documentation_handler_get_stats(tmp_path, mocker):
|
||||
"""Test DocumentationHandler returns correct stats."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
(docs_root / "doc1.md").write_text("# Doc 1")
|
||||
(docs_root / "doc2.md").write_text("# Doc 2")
|
||||
(docs_root / "doc3.mdx").write_text("# Doc 3")
|
||||
|
||||
# Mock embedded count query (1 doc has embedding)
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
@@ -462,11 +224,13 @@ async def test_documentation_handler_title_extraction(tmp_path):
|
||||
"""Test DocumentationHandler extracts title from markdown heading."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test with heading
|
||||
doc_with_heading = tmp_path / "with_heading.md"
|
||||
doc_with_heading.write_text("# My Title\n\nContent here")
|
||||
title = handler._extract_doc_title(doc_with_heading)
|
||||
assert title == "My Title"
|
||||
|
||||
# Test without heading
|
||||
doc_without_heading = tmp_path / "no-heading.md"
|
||||
doc_without_heading.write_text("Just content, no heading")
|
||||
title = handler._extract_doc_title(doc_without_heading)
|
||||
@@ -478,6 +242,7 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
"""Test DocumentationHandler chunks markdown by headings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test document with multiple sections
|
||||
doc_with_sections = tmp_path / "sections.md"
|
||||
doc_with_sections.write_text(
|
||||
"# Document Title\n\n"
|
||||
@@ -489,6 +254,7 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
)
|
||||
sections = handler._chunk_markdown_by_headings(doc_with_sections)
|
||||
|
||||
# Should have 3 sections: intro (with doc title), section one, section two
|
||||
assert len(sections) == 3
|
||||
assert sections[0].title == "Document Title"
|
||||
assert sections[0].index == 0
|
||||
@@ -502,6 +268,7 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
assert sections[2].index == 2
|
||||
assert "Content for section two" in sections[2].content
|
||||
|
||||
# Test document without headings
|
||||
doc_no_sections = tmp_path / "no-sections.md"
|
||||
doc_no_sections.write_text("Just plain content without any headings.")
|
||||
sections = handler._chunk_markdown_by_headings(doc_no_sections)
|
||||
@@ -515,39 +282,21 @@ async def test_documentation_handler_section_content_ids():
|
||||
"""Test DocumentationHandler creates and parses section content IDs."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test making content ID
|
||||
content_id = handler._make_section_content_id("docs/guide.md", 2)
|
||||
assert content_id == "docs/guide.md::2"
|
||||
|
||||
# Test parsing content ID
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
|
||||
assert doc_path == "docs/guide.md"
|
||||
assert section_index == 2
|
||||
|
||||
# Test parsing legacy format (no section index)
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
|
||||
assert doc_path == "docs/old-format.md"
|
||||
assert section_index == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handlers_registry():
|
||||
"""Test all content types are registered."""
|
||||
@@ -558,3 +307,88 @@ async def test_content_handlers_registry():
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_empty_attributes():
|
||||
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with empty values (all attributes exist but are falsy)
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
mock_block_instance.disabled = False
|
||||
mock_block_instance.description = ""
|
||||
mock_block_instance.categories = set()
|
||||
mock_block_instance.input_schema.model_fields = {}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock one good block and one bad block
|
||||
good_block = MagicMock()
|
||||
good_instance = MagicMock()
|
||||
good_instance.name = "Good Block"
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_instance.disabled = False
|
||||
good_instance.input_schema.model_fields = {}
|
||||
good_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
bad_block.side_effect = Exception("Instantiation failed")
|
||||
|
||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
# Should only get the good block
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Mock _get_docs_root to return non-existent path
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
|
||||
@@ -9,7 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.db import transaction
|
||||
from backend.data.graph import (
|
||||
GraphModel,
|
||||
GraphModelWithoutNodes,
|
||||
@@ -104,8 +104,7 @@ async def get_store_agents(
|
||||
# search_used_hybrid remains False, will use fallback path below
|
||||
|
||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||
# Fall through to direct DB search if hybrid returned nothing
|
||||
if search_used_hybrid and agents:
|
||||
if search_used_hybrid:
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
@@ -131,20 +130,52 @@ async def get_store_agents(
|
||||
)
|
||||
continue
|
||||
|
||||
if not search_used_hybrid or not agents:
|
||||
# Fallback path: direct DB query with optional tsvector search.
|
||||
# This mirrors the original pre-hybrid-search implementation.
|
||||
store_agents, total = await _fallback_store_agent_search(
|
||||
search_query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
if not search_used_hybrid:
|
||||
# Fallback path - use basic search or no search
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
# Add basic text search if search_query provided but hybrid failed
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == StoreAgentsSortOptions.RATING:
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.RUNS:
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.NAME:
|
||||
order_by.append({"agent_name": "asc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.UPDATED_AT:
|
||||
order_by.append({"updated_at": "desc"})
|
||||
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in db_agents:
|
||||
try:
|
||||
store_agents.append(store_model.StoreAgent.from_db(agent))
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing StoreAgent from db: {e}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(store_agents)} agents")
|
||||
return store_model.StoreAgentsResponse(
|
||||
agents=store_agents,
|
||||
@@ -164,126 +195,6 @@ async def get_store_agents(
|
||||
# await log_search_term(search_query=search_term)
|
||||
|
||||
|
||||
async def _fallback_store_agent_search(
|
||||
*,
|
||||
search_query: str | None,
|
||||
featured: bool,
|
||||
creators: list[str] | None,
|
||||
category: str | None,
|
||||
sorted_by: StoreAgentsSortOptions | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> tuple[list[store_model.StoreAgent], int]:
|
||||
"""Direct DB search fallback when hybrid search is unavailable or empty.
|
||||
|
||||
Uses ad-hoc to_tsvector/plainto_tsquery with ts_rank_cd for text search,
|
||||
matching the quality of the original pre-hybrid-search implementation.
|
||||
Falls back to simple listing when no search query is provided.
|
||||
"""
|
||||
if not search_query:
|
||||
# No search query — use Prisma for simple filtered listing
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
order_by = []
|
||||
if sorted_by == StoreAgentsSortOptions.RATING:
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.RUNS:
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.NAME:
|
||||
order_by.append({"agent_name": "asc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.UPDATED_AT:
|
||||
order_by.append({"updated_at": "desc"})
|
||||
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
return [store_model.StoreAgent.from_db(a) for a in db_agents], total
|
||||
|
||||
# Text search using ad-hoc tsvector on StoreAgent view fields
|
||||
params: list[Any] = [search_query]
|
||||
filters = ["sa.is_available = true"]
|
||||
param_idx = 2
|
||||
|
||||
if featured:
|
||||
filters.append("sa.featured = true")
|
||||
if creators:
|
||||
params.append(creators)
|
||||
filters.append(f"sa.creator_username = ANY(${param_idx})")
|
||||
param_idx += 1
|
||||
if category:
|
||||
params.append(category)
|
||||
filters.append(f"${param_idx} = ANY(sa.categories)")
|
||||
param_idx += 1
|
||||
|
||||
where_sql = " AND ".join(filters)
|
||||
|
||||
params.extend([page_size, (page - 1) * page_size])
|
||||
limit_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
offset_param = f"${param_idx}"
|
||||
|
||||
sql = f"""
|
||||
WITH ranked AS (
|
||||
SELECT sa.*,
|
||||
ts_rank_cd(
|
||||
to_tsvector('english',
|
||||
COALESCE(sa.agent_name, '') || ' ' ||
|
||||
COALESCE(sa.sub_heading, '') || ' ' ||
|
||||
COALESCE(sa.description, '')
|
||||
),
|
||||
plainto_tsquery('english', $1)
|
||||
) AS rank,
|
||||
COUNT(*) OVER () AS total_count
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
WHERE {where_sql}
|
||||
AND to_tsvector('english',
|
||||
COALESCE(sa.agent_name, '') || ' ' ||
|
||||
COALESCE(sa.sub_heading, '') || ' ' ||
|
||||
COALESCE(sa.description, '')
|
||||
) @@ plainto_tsquery('english', $1)
|
||||
)
|
||||
SELECT * FROM ranked
|
||||
ORDER BY rank DESC
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(sql, *params)
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
store_agents = []
|
||||
for row in results:
|
||||
try:
|
||||
store_agents.append(
|
||||
store_model.StoreAgent(
|
||||
slug=row["slug"],
|
||||
agent_name=row["agent_name"],
|
||||
agent_image=row["agent_image"][0] if row["agent_image"] else "",
|
||||
creator=row["creator_username"] or "Needs Profile",
|
||||
creator_avatar=row["creator_avatar"] or "",
|
||||
sub_heading=row["sub_heading"],
|
||||
description=row["description"],
|
||||
runs=row["runs"],
|
||||
rating=row["rating"],
|
||||
agent_graph_id=row.get("graph_id", ""),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing StoreAgent from fallback search: {e}")
|
||||
continue
|
||||
|
||||
return store_agents, total
|
||||
|
||||
|
||||
async def log_search_term(search_query: str):
|
||||
"""Log a search term to the database"""
|
||||
|
||||
@@ -1228,21 +1139,16 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (best-effort)
|
||||
try:
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=submission.name,
|
||||
description=submission.description,
|
||||
sub_heading=submission.subHeading,
|
||||
categories=submission.categories,
|
||||
tx=tx,
|
||||
)
|
||||
except Exception as emb_err:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for listing "
|
||||
f"{store_listing_version_id}: {emb_err}"
|
||||
)
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=submission.name,
|
||||
description=submission.description,
|
||||
sub_heading=submission.subHeading,
|
||||
categories=submission.categories,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": submission.storeListingId},
|
||||
|
||||
@@ -15,7 +15,6 @@ from prisma.enums import ContentType
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.api.features.store.content_handlers import CONTENT_HANDLERS
|
||||
from backend.blocks import get_blocks
|
||||
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.json import dumps
|
||||
@@ -663,6 +662,8 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
||||
)
|
||||
current_ids = {row["id"] for row in valid_agents}
|
||||
elif content_type == ContentType.BLOCK:
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
current_ids = set(get_blocks().keys())
|
||||
elif content_type == ContentType.DOCUMENTATION:
|
||||
# Use DocumentationHandler to get section-based content IDs
|
||||
|
||||
@@ -31,10 +31,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Tokenize text for BM25."""
|
||||
"""Simple tokenizer for BM25 - lowercase and split on non-alphanumeric."""
|
||||
if not text:
|
||||
return []
|
||||
return re.findall(r"\b\w+\b", text.lower())
|
||||
# Lowercase and split on non-alphanumeric characters
|
||||
tokens = re.findall(r"\b\w+\b", text.lower())
|
||||
return tokens
|
||||
|
||||
|
||||
def bm25_rerank(
|
||||
|
||||
@@ -14,27 +14,9 @@ from backend.api.features.store.hybrid_search import (
|
||||
HybridSearchWeights,
|
||||
UnifiedSearchWeights,
|
||||
hybrid_search,
|
||||
tokenize,
|
||||
unified_hybrid_search,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tokenize (BM25)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_text, expected",
|
||||
[
|
||||
("AITextGeneratorBlock", ["aitextgeneratorblock"]),
|
||||
("hello world", ["hello", "world"]),
|
||||
("", []),
|
||||
("HTTPRequest", ["httprequest"]),
|
||||
],
|
||||
)
|
||||
def test_tokenize(input_text: str, expected: list[str]):
|
||||
assert tokenize(input_text) == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import urllib.parse
|
||||
|
||||
import autogpt_libs.auth
|
||||
@@ -258,18 +259,21 @@ async def get_graph_meta_by_store_listing_version_id(
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str,
|
||||
) -> fastapi.responses.Response:
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""Download agent graph file for a specific marketplace listing version"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
return fastapi.responses.Response(
|
||||
content=backend.util.json.dumps(graph_data),
|
||||
media_type="application/json",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{file_name}"',
|
||||
},
|
||||
)
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Backward-compatibility shim — ``split_camelcase`` now lives in backend.util.text."""
|
||||
|
||||
from backend.util.text import split_camelcase # noqa: F401
|
||||
|
||||
__all__ = ["split_camelcase"]
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Tests for split_camelcase (now in backend.util.text)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# split_camelcase
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_text, expected",
|
||||
[
|
||||
("AITextGeneratorBlock", "AI Text Generator Block"),
|
||||
("HTTPRequestBlock", "HTTP Request Block"),
|
||||
("simpleWord", "simple Word"),
|
||||
("already spaced", "already spaced"),
|
||||
("XMLParser", "XML Parser"),
|
||||
("getHTTPResponse", "get HTTP Response"),
|
||||
("Block", "Block"),
|
||||
("", ""),
|
||||
("OAuth2Block", "OAuth2 Block"),
|
||||
("IOError", "IO Error"),
|
||||
("getHTTPSResponse", "get HTTPS Response"),
|
||||
# Known limitation: single-letter uppercase prefixes are NOT split.
|
||||
# "ABlock" stays "ABlock" because the algorithm requires the left
|
||||
# part of an uppercase run to retain at least 2 uppercase chars.
|
||||
("ABlock", "ABlock"),
|
||||
# Digit-to-uppercase transitions
|
||||
("Base64Encoder", "Base64 Encoder"),
|
||||
("UTF8Decoder", "UTF8 Decoder"),
|
||||
# Pure digits — no camelCase boundaries to split
|
||||
("123", "123"),
|
||||
# Known limitation: single-letter uppercase segments after digits
|
||||
# are not split from the following word. "3D" is only 1 uppercase
|
||||
# char so the uppercase-run rule cannot fire, producing "3 DRenderer"
|
||||
# rather than the ideal "3D Renderer".
|
||||
("3DRenderer", "3 DRenderer"),
|
||||
# Exception list — compound terms that should stay together
|
||||
("YouTubeBlock", "YouTube Block"),
|
||||
("OpenAIBlock", "OpenAI Block"),
|
||||
("AutoGPTAgent", "AutoGPT Agent"),
|
||||
("GitHubIntegration", "GitHub Integration"),
|
||||
("LinkedInBlock", "LinkedIn Block"),
|
||||
],
|
||||
)
|
||||
def test_split_camelcase(input_text: str, expected: str):
|
||||
assert split_camelcase(input_text) == expected
|
||||
@@ -94,8 +94,3 @@ class NotificationPayload(pydantic.BaseModel):
|
||||
|
||||
class OnboardingNotificationPayload(NotificationPayload):
|
||||
step: OnboardingStep | None
|
||||
|
||||
|
||||
class CopilotCompletionPayload(NotificationPayload):
|
||||
session_id: str
|
||||
status: Literal["completed", "failed"]
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
"""
|
||||
Shared configuration for all AgentMail blocks.
|
||||
"""
|
||||
|
||||
from agentmail import AsyncAgentMail
|
||||
|
||||
from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
|
||||
|
||||
agent_mail = (
|
||||
ProviderBuilder("agent_mail")
|
||||
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="agent_mail",
|
||||
title="Mock AgentMail API Key",
|
||||
api_key=SecretStr("mock-agentmail-api-key"),
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def _client(credentials: APIKeyCredentials) -> AsyncAgentMail:
|
||||
"""Create an AsyncAgentMail client from credentials."""
|
||||
return AsyncAgentMail(api_key=credentials.api_key.get_secret_value())
|
||||
@@ -1,211 +0,0 @@
|
||||
"""
|
||||
AgentMail Attachment blocks — download file attachments from messages and threads.
|
||||
|
||||
Attachments are files associated with messages (PDFs, CSVs, images, etc.).
|
||||
To send attachments, include them in the attachments parameter when using
|
||||
AgentMailSendMessageBlock or AgentMailReplyToMessageBlock.
|
||||
|
||||
To download, first get the attachment_id from a message's attachments array,
|
||||
then use these blocks to retrieve the file content as base64.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailGetMessageAttachmentBlock(Block):
|
||||
"""
|
||||
Download a file attachment from a specific email message.
|
||||
|
||||
Retrieves the raw file content and returns it as base64-encoded data.
|
||||
First get the attachment_id from a message object's attachments array,
|
||||
then use this block to download the file.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID containing the attachment"
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="Attachment ID to download (from the message's attachments array)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
content_base64: str = SchemaField(
|
||||
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="The attachment ID that was downloaded"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a283ffc4-8087-4c3d-9135-8f26b86742ec",
|
||||
description="Download a file attachment from an email message. Returns base64-encoded file content.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"attachment_id": "test-attach",
|
||||
},
|
||||
test_output=[
|
||||
("content_base64", "dGVzdA=="),
|
||||
("attachment_id", "test-attach"),
|
||||
],
|
||||
test_mock={
|
||||
"get_attachment": lambda *a, **kw: b"test",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_attachment(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
message_id: str,
|
||||
attachment_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.get_attachment(
|
||||
inbox_id=inbox_id,
|
||||
message_id=message_id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_attachment(
|
||||
credentials=credentials,
|
||||
inbox_id=input_data.inbox_id,
|
||||
message_id=input_data.message_id,
|
||||
attachment_id=input_data.attachment_id,
|
||||
)
|
||||
if isinstance(data, bytes):
|
||||
encoded = base64.b64encode(data).decode()
|
||||
elif isinstance(data, str):
|
||||
encoded = base64.b64encode(data.encode("utf-8")).decode()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected attachment data type: {type(data).__name__}"
|
||||
)
|
||||
|
||||
yield "content_base64", encoded
|
||||
yield "attachment_id", input_data.attachment_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetThreadAttachmentBlock(Block):
|
||||
"""
|
||||
Download a file attachment from a conversation thread.
|
||||
|
||||
Same as GetMessageAttachment but looks up by thread ID instead of
|
||||
message ID. Useful when you know the thread but not the specific
|
||||
message containing the attachment.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID containing the attachment")
|
||||
attachment_id: str = SchemaField(
|
||||
description="Attachment ID to download (from a message's attachments array within the thread)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
content_base64: str = SchemaField(
|
||||
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="The attachment ID that was downloaded"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="06b6a4c4-9d71-4992-9e9c-cf3b352763b5",
|
||||
description="Download a file attachment from a conversation thread. Returns base64-encoded file content.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
"attachment_id": "test-attach",
|
||||
},
|
||||
test_output=[
|
||||
("content_base64", "dGVzdA=="),
|
||||
("attachment_id", "test-attach"),
|
||||
],
|
||||
test_mock={
|
||||
"get_attachment": lambda *a, **kw: b"test",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_attachment(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
thread_id: str,
|
||||
attachment_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.get_attachment(
|
||||
inbox_id=inbox_id,
|
||||
thread_id=thread_id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_attachment(
|
||||
credentials=credentials,
|
||||
inbox_id=input_data.inbox_id,
|
||||
thread_id=input_data.thread_id,
|
||||
attachment_id=input_data.attachment_id,
|
||||
)
|
||||
if isinstance(data, bytes):
|
||||
encoded = base64.b64encode(data).decode()
|
||||
elif isinstance(data, str):
|
||||
encoded = base64.b64encode(data.encode("utf-8")).decode()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected attachment data type: {type(data).__name__}"
|
||||
)
|
||||
|
||||
yield "content_base64", encoded
|
||||
yield "attachment_id", input_data.attachment_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,678 +0,0 @@
|
||||
"""
|
||||
AgentMail Draft blocks — create, get, list, update, send, and delete drafts.
|
||||
|
||||
A Draft is an unsent message that can be reviewed, edited, and sent later.
|
||||
Drafts enable human-in-the-loop review, scheduled sending (via send_at),
|
||||
and complex multi-step email composition workflows.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreateDraftBlock(Block):
|
||||
"""
|
||||
Create a draft email in an AgentMail inbox for review or scheduled sending.
|
||||
|
||||
Drafts let agents prepare emails without sending immediately. Use send_at
|
||||
to schedule automatic sending at a future time (ISO 8601 format).
|
||||
Scheduled drafts are auto-labeled 'scheduled' and can be cancelled by
|
||||
deleting the draft.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to create the draft in"
|
||||
)
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses (e.g. ['user@example.com'])"
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject line", default="")
|
||||
text: str = SchemaField(description="Plain text body of the draft", default="")
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the draft", default="", advanced=True
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
in_reply_to: str = SchemaField(
|
||||
description="Message ID this draft replies to, for threading follow-up drafts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
send_at: str = SchemaField(
|
||||
description="Schedule automatic sending at this ISO 8601 datetime (e.g. '2025-01-15T09:00:00Z'). Leave empty for manual send.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(
|
||||
description="Unique identifier of the created draft"
|
||||
)
|
||||
send_status: str = SchemaField(
|
||||
description="'scheduled' if send_at was set, empty otherwise. Values: scheduled, sending, failed.",
|
||||
default="",
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete draft object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="25ac9086-69fd-48b8-b910-9dbe04b8f3bd",
|
||||
description="Create a draft email for review or scheduled sending. Use send_at for automatic future delivery.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"to": ["user@example.com"],
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "mock-draft-id"),
|
||||
("send_status", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "mock-draft-id",
|
||||
"send_status": "",
|
||||
"model_dump": lambda self: {"draft_id": "mock-draft-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_draft(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.create(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"to": input_data.to}
|
||||
if input_data.subject:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.in_reply_to:
|
||||
params["in_reply_to"] = input_data.in_reply_to
|
||||
if input_data.send_at:
|
||||
params["send_at"] = input_data.send_at
|
||||
|
||||
draft = await self.create_draft(credentials, input_data.inbox_id, **params)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetDraftBlock(Block):
|
||||
"""
|
||||
Retrieve a specific draft from an AgentMail inbox.
|
||||
|
||||
Returns the draft contents including recipients, subject, body, and
|
||||
scheduled send status. Use this to review a draft before approving it.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(description="Unique identifier of the draft")
|
||||
subject: str = SchemaField(description="Draft subject line", default="")
|
||||
send_status: str = SchemaField(
|
||||
description="Scheduled send status: 'scheduled', 'sending', 'failed', or empty",
|
||||
default="",
|
||||
)
|
||||
send_at: str = SchemaField(
|
||||
description="Scheduled send time (ISO 8601) if set", default=""
|
||||
)
|
||||
result: dict = SchemaField(description="Complete draft object with all fields")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8e57780d-dc25-43d4-a0f4-1f02877b09fb",
|
||||
description="Retrieve a draft email to review its contents, recipients, and scheduled send status.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "test-draft"),
|
||||
("subject", ""),
|
||||
("send_status", ""),
|
||||
("send_at", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "test-draft",
|
||||
"subject": "",
|
||||
"send_status": "",
|
||||
"send_at": "",
|
||||
"model_dump": lambda self: {"draft_id": "test-draft"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.get(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
draft = await self.get_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "subject", draft.subject or ""
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "send_at", draft.send_at or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts in an AgentMail inbox with optional label filtering.
|
||||
|
||||
Use labels=['scheduled'] to find all drafts queued for future sending.
|
||||
Useful for building approval dashboards or monitoring pending outreach.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list drafts from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Filter drafts by labels (e.g. ['scheduled'] for pending sends)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects with subject, recipients, send_status, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e84883b7-7c39-4c5c-88e8-0a72b078ea63",
|
||||
description="List drafts in an AgentMail inbox. Filter by labels=['scheduled'] to find pending sends.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_drafts(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.list(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_drafts(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateDraftBlock(Block):
|
||||
"""
|
||||
Update an existing draft's content, recipients, or scheduled send time.
|
||||
|
||||
Use this to reschedule a draft (change send_at), modify recipients,
|
||||
or edit the subject/body before sending. To cancel a scheduled send,
|
||||
delete the draft instead.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to update")
|
||||
to: Optional[list[str]] = SchemaField(
|
||||
description="Updated recipient email addresses (replaces existing list). Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
subject: Optional[str] = SchemaField(
|
||||
description="Updated subject line. Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
text: Optional[str] = SchemaField(
|
||||
description="Updated plain text body. Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
html: Optional[str] = SchemaField(
|
||||
description="Updated HTML body. Omit to keep current value.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
send_at: Optional[str] = SchemaField(
|
||||
description="Reschedule: new ISO 8601 send time (e.g. '2025-01-20T14:00:00Z'). Omit to keep current value.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(description="The updated draft ID")
|
||||
send_status: str = SchemaField(description="Updated send status", default="")
|
||||
result: dict = SchemaField(description="Complete updated draft object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="351f6e51-695a-421a-9032-46a587b10336",
|
||||
description="Update a draft's content, recipients, or scheduled send time. Use to reschedule or edit before sending.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "test-draft"),
|
||||
("send_status", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "test-draft",
|
||||
"send_status": "",
|
||||
"model_dump": lambda self: {"draft_id": "test-draft"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_draft(
|
||||
credentials: APIKeyCredentials, inbox_id: str, draft_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.update(
|
||||
inbox_id=inbox_id, draft_id=draft_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.to is not None:
|
||||
params["to"] = input_data.to
|
||||
if input_data.subject is not None:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text is not None:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html is not None:
|
||||
params["html"] = input_data.html
|
||||
if input_data.send_at is not None:
|
||||
params["send_at"] = input_data.send_at
|
||||
|
||||
draft = await self.update_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id, **params
|
||||
)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailSendDraftBlock(Block):
|
||||
"""
|
||||
Send a draft immediately, converting it into a delivered message.
|
||||
|
||||
The draft is deleted after successful sending and becomes a regular
|
||||
message with a message_id. Use this for human-in-the-loop approval
|
||||
workflows: agent creates draft, human reviews, then this block sends it.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to send now")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID of the now-sent email (draft is deleted)"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID the sent message belongs to"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete sent message object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="37c39e83-475d-4b3d-843a-d923d001b85a",
|
||||
description="Send a draft immediately, converting it into a delivered message. The draft is deleted after sending.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-msg-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"send_draft": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-msg-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.send(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
msg = await self.send_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteDraftBlock(Block):
|
||||
"""
|
||||
Delete a draft from an AgentMail inbox. Also cancels any scheduled send.
|
||||
|
||||
If the draft was scheduled with send_at, deleting it cancels the
|
||||
scheduled delivery. This is the way to cancel a scheduled email.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(
|
||||
description="Draft ID to delete (also cancels scheduled sends)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the draft was successfully deleted/cancelled"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9023eb99-3e2f-4def-808b-d9c584b3d9e7",
|
||||
description="Delete a draft or cancel a scheduled email. Removes the draft permanently.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_draft": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_draft(
|
||||
credentials: APIKeyCredentials, inbox_id: str, draft_id: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.drafts.delete(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListOrgDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts across every inbox in your organization.
|
||||
|
||||
Returns drafts from all inboxes in one query. Perfect for building
|
||||
a central approval dashboard where a human supervisor can review
|
||||
and approve any draft created by any agent.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects from all inboxes in the organization"
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ed7558ae-3a07-45f5-af55-a25fe88c9971",
|
||||
description="List all drafts across every inbox in your organization. Use for central approval dashboards.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_org_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_org_drafts(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.drafts.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_org_drafts(credentials, **params)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,414 +0,0 @@
|
||||
"""
|
||||
AgentMail Inbox blocks — create, get, list, update, and delete inboxes.
|
||||
|
||||
An Inbox is a fully programmable email account for AI agents. Each inbox gets
|
||||
a unique email address and can send, receive, and manage emails via the
|
||||
AgentMail API. You can create thousands of inboxes on demand.
|
||||
"""
|
||||
|
||||
from agentmail.inboxes.types import CreateInboxRequest
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreateInboxBlock(Block):
|
||||
"""
|
||||
Create a new email inbox for an AI agent via AgentMail.
|
||||
|
||||
Each inbox gets a unique email address (e.g. username@agentmail.to).
|
||||
If username and domain are not provided, AgentMail auto-generates them.
|
||||
Use custom domains by specifying the domain field.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
username: str = SchemaField(
|
||||
description="Local part of the email address (e.g. 'support' for support@domain.com). Leave empty to auto-generate.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field of sent emails (e.g. 'Support Agent')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(
|
||||
description="Unique identifier for the created inbox (also the email address)"
|
||||
)
|
||||
email_address: str = SchemaField(
|
||||
description="Full email address of the inbox (e.g. support@agentmail.to)"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7a8ac219-c6ec-4eec-a828-81af283ce04c",
|
||||
description="Create a new email inbox for an AI agent via AgentMail. Each inbox gets a unique address and can send/receive emails.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("inbox_id", "mock-inbox-id"),
|
||||
("email_address", "mock-inbox-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "mock-inbox-id",
|
||||
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_inbox(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.create(request=CreateInboxRequest(**params))
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.username:
|
||||
params["username"] = input_data.username
|
||||
if input_data.domain:
|
||||
params["domain"] = input_data.domain
|
||||
if input_data.display_name:
|
||||
params["display_name"] = input_data.display_name
|
||||
|
||||
inbox = await self.create_inbox(credentials, **params)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetInboxBlock(Block):
|
||||
"""
|
||||
Retrieve details of an existing AgentMail inbox by its ID or email address.
|
||||
|
||||
Returns the inbox metadata including email address, display name, and
|
||||
configuration. Use this to check if an inbox exists or get its properties.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to look up (e.g. 'support@agentmail.to')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(description="Unique identifier of the inbox")
|
||||
email_address: str = SchemaField(description="Full email address of the inbox")
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field", default=""
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b858f62b-6c12-4736-aaf2-dbc5a9281320",
|
||||
description="Retrieve details of an existing AgentMail inbox including its email address, display name, and configuration.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("inbox_id", "test-inbox"),
|
||||
("email_address", "test-inbox"),
|
||||
("display_name", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "test-inbox",
|
||||
"display_name": "",
|
||||
"model_dump": lambda self: {"inbox_id": "test-inbox"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_inbox(credentials: APIKeyCredentials, inbox_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.get(inbox_id=inbox_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
inbox = await self.get_inbox(credentials, input_data.inbox_id)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "display_name", inbox.display_name or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListInboxesBlock(Block):
|
||||
"""
|
||||
List all email inboxes in your AgentMail organization.
|
||||
|
||||
Returns a paginated list of all inboxes with their metadata.
|
||||
Use page_token for pagination when you have many inboxes.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of inboxes to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page of results",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inboxes: list[dict] = SchemaField(
|
||||
description="List of inbox objects, each containing inbox_id, email_address, display_name, etc."
|
||||
)
|
||||
count: int = SchemaField(
|
||||
description="Total number of inboxes in your organization"
|
||||
)
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token to pass as page_token to get the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cfd84a06-2121-4cef-8d14-8badf52d22f0",
|
||||
description="List all email inboxes in your AgentMail organization with pagination support.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("inboxes", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_inboxes": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"inboxes": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_inboxes(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_inboxes(credentials, **params)
|
||||
inboxes = [i.model_dump() for i in response.inboxes]
|
||||
|
||||
yield "inboxes", inboxes
|
||||
yield "count", (c if (c := response.count) is not None else len(inboxes))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateInboxBlock(Block):
|
||||
"""
|
||||
Update the display name of an existing AgentMail inbox.
|
||||
|
||||
Changes the friendly name shown in the 'From' field when emails are sent
|
||||
from this inbox. The email address itself cannot be changed.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to update (e.g. 'support@agentmail.to')"
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="New display name for the inbox (e.g. 'Customer Support Bot')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(description="The updated inbox ID")
|
||||
result: dict = SchemaField(
|
||||
description="Complete updated inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59b49f59-a6d1-4203-94c0-3908adac50b6",
|
||||
description="Update the display name of an AgentMail inbox. Changes the 'From' name shown when emails are sent.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"display_name": "Updated",
|
||||
},
|
||||
test_output=[
|
||||
("inbox_id", "test-inbox"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "test-inbox",
|
||||
"model_dump": lambda self: {"inbox_id": "test-inbox"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_inbox(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.update(inbox_id=inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
inbox = await self.update_inbox(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
display_name=input_data.display_name,
|
||||
)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteInboxBlock(Block):
|
||||
"""
|
||||
Permanently delete an AgentMail inbox and all its data.
|
||||
|
||||
This removes the inbox, all its messages, threads, and drafts.
|
||||
This action cannot be undone. The email address will no longer
|
||||
receive or send emails.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to permanently delete"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the inbox was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ade970ae-8428-4a7b-9278-b52054dbf535",
|
||||
description="Permanently delete an AgentMail inbox and all its messages, threads, and drafts. This action cannot be undone.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_inbox": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_inbox(credentials: APIKeyCredentials, inbox_id: str):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.delete(inbox_id=inbox_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_inbox(credentials, input_data.inbox_id)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,384 +0,0 @@
|
||||
"""
|
||||
AgentMail List blocks — manage allow/block lists for email filtering.
|
||||
|
||||
Lists let you control which email addresses and domains your agents can
|
||||
send to or receive from. There are four list types based on two dimensions:
|
||||
direction (send/receive) and type (allow/block).
|
||||
|
||||
- receive + allow: Only accept emails from these addresses/domains
|
||||
- receive + block: Reject emails from these addresses/domains
|
||||
- send + allow: Only send emails to these addresses/domains
|
||||
- send + block: Prevent sending emails to these addresses/domains
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class ListDirection(str, Enum):
|
||||
SEND = "send"
|
||||
RECEIVE = "receive"
|
||||
|
||||
|
||||
class ListType(str, Enum):
|
||||
ALLOW = "allow"
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
class AgentMailListEntriesBlock(Block):
|
||||
"""
|
||||
List all entries in an AgentMail allow/block list.
|
||||
|
||||
Retrieves email addresses and domains that are currently allowed
|
||||
or blocked for sending or receiving. Use direction and list_type
|
||||
to select which of the four lists to query.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' to filter outgoing emails, 'receive' to filter incoming emails"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist (only permit these), 'block' for blacklist (reject these)"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of entries to return per page",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entries: list[dict] = SchemaField(
|
||||
description="List of entries, each with an email address or domain"
|
||||
)
|
||||
count: int = SchemaField(description="Number of entries returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="01489100-35da-45aa-8a01-9540ba0e9a21",
|
||||
description="List all entries in an AgentMail allow/block list. Choose send/receive direction and allow/block type.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
},
|
||||
test_output=[
|
||||
("entries", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_entries": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"entries": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_entries(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.list(direction, list_type, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_entries(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
**params,
|
||||
)
|
||||
entries = [e.model_dump() for e in response.entries]
|
||||
|
||||
yield "entries", entries
|
||||
yield "count", (c if (c := response.count) is not None else len(entries))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailCreateListEntryBlock(Block):
|
||||
"""
|
||||
Add an email address or domain to an AgentMail allow/block list.
|
||||
|
||||
Entries can be full email addresses (e.g. 'partner@example.com') or
|
||||
entire domains (e.g. 'example.com'). For block lists, you can optionally
|
||||
provide a reason (e.g. 'spam', 'competitor').
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing email rules, 'receive' for incoming email rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' to whitelist, 'block' to blacklist"
|
||||
)
|
||||
entry: str = SchemaField(
|
||||
description="Email address (user@example.com) or domain (example.com) to add"
|
||||
)
|
||||
reason: str = SchemaField(
|
||||
description="Reason for blocking (only used with block lists, e.g. 'spam', 'competitor')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entry: str = SchemaField(
|
||||
description="The email address or domain that was added"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete entry object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b6650a0a-b113-40cf-8243-ff20f684f9b8",
|
||||
description="Add an email address or domain to an allow/block list. Block spam senders or whitelist trusted domains.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[
|
||||
("entry", "spam@example.com"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_entry": lambda *a, **kw: type(
|
||||
"Entry",
|
||||
(),
|
||||
{
|
||||
"model_dump": lambda self: {"entry": "spam@example.com"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.create(direction, list_type, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"entry": input_data.entry}
|
||||
if input_data.reason and input_data.list_type == ListType.BLOCK:
|
||||
params["reason"] = input_data.reason
|
||||
|
||||
result = await self.create_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
**params,
|
||||
)
|
||||
result_dict = result.model_dump()
|
||||
|
||||
yield "entry", input_data.entry
|
||||
yield "result", result_dict
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetListEntryBlock(Block):
|
||||
"""
|
||||
Check if an email address or domain exists in an AgentMail allow/block list.
|
||||
|
||||
Returns the entry details if found. Use this to verify whether a specific
|
||||
address or domain is currently allowed or blocked.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing rules, 'receive' for incoming rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist, 'block' for blacklist"
|
||||
)
|
||||
entry: str = SchemaField(description="Email address or domain to look up")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entry: str = SchemaField(
|
||||
description="The email address or domain that was found"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete entry object with metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fb117058-ab27-40d1-9231-eb1dd526fc7a",
|
||||
description="Check if an email address or domain is in an allow/block list. Verify filtering rules.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[
|
||||
("entry", "spam@example.com"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_entry": lambda *a, **kw: type(
|
||||
"Entry",
|
||||
(),
|
||||
{
|
||||
"model_dump": lambda self: {"entry": "spam@example.com"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.get(direction, list_type, entry=entry)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.get_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
input_data.entry,
|
||||
)
|
||||
result_dict = result.model_dump()
|
||||
|
||||
yield "entry", input_data.entry
|
||||
yield "result", result_dict
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteListEntryBlock(Block):
|
||||
"""
|
||||
Remove an email address or domain from an AgentMail allow/block list.
|
||||
|
||||
After removal, the address/domain will no longer be filtered by this list.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing rules, 'receive' for incoming rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist, 'block' for blacklist"
|
||||
)
|
||||
entry: str = SchemaField(
|
||||
description="Email address or domain to remove from the list"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the entry was successfully removed"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2b8d57f1-1c9e-470f-a70b-5991c80fad5f",
|
||||
description="Remove an email address or domain from an allow/block list to stop filtering it.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_entry": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.lists.delete(direction, list_type, entry=entry)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
input_data.entry,
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,695 +0,0 @@
|
||||
"""
|
||||
AgentMail Message blocks — send, list, get, reply, forward, and update messages.
|
||||
|
||||
A Message is an individual email within a Thread. Agents can send new messages
|
||||
(which create threads), reply to existing messages, forward them, and manage
|
||||
labels for state tracking (e.g. read/unread, campaign tags).
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailSendMessageBlock(Block):
|
||||
"""
|
||||
Send a new email from an AgentMail inbox, automatically creating a new thread.
|
||||
|
||||
Supports plain text and HTML bodies, CC/BCC recipients, and labels for
|
||||
organizing messages (e.g. campaign tracking, state management).
|
||||
Max 50 combined recipients across to, cc, and bcc.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to send from (e.g. 'agent@agentmail.to')"
|
||||
)
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses (e.g. ['user@example.com'])"
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject line")
|
||||
text: str = SchemaField(
|
||||
description="Plain text body of the email. Always provide this as a fallback for email clients that don't render HTML."
|
||||
)
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the email. Embed CSS in a <style> tag for best compatibility across email clients.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses for human-in-the-loop oversight",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses (hidden from other recipients)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Labels to tag the message for filtering and state management (e.g. ['outreach', 'q4-campaign'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the sent message"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID grouping this message and any future replies"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete sent message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b67469b2-7748-4d81-a223-4ebd332cca89",
|
||||
description="Send a new email from an AgentMail inbox. Creates a new conversation thread. Supports HTML, CC/BCC, and labels.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"to": ["user@example.com"],
|
||||
"subject": "Test",
|
||||
"text": "Hello",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-msg-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"send_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_message(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.send(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
|
||||
if total > 50:
|
||||
raise ValueError(
|
||||
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"to": input_data.to,
|
||||
"subject": input_data.subject,
|
||||
"text": input_data.text,
|
||||
}
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
msg = await self.send_message(credentials, input_data.inbox_id, **params)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListMessagesBlock(Block):
|
||||
"""
|
||||
List all messages in an AgentMail inbox with optional label filtering.
|
||||
|
||||
Returns a paginated list of messages. Use labels to filter (e.g.
|
||||
labels=['unread'] to only get unprocessed messages). Useful for
|
||||
polling workflows or building inbox views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list messages from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of messages to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return messages with ALL of these labels (e.g. ['unread'] or ['q4-campaign', 'follow-up'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
messages: list[dict] = SchemaField(
|
||||
description="List of message objects with subject, sender, text, html, labels, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of messages returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="721234df-c7a2-4927-b205-744badbd5844",
|
||||
description="List messages in an AgentMail inbox. Filter by labels to find unread, campaign-tagged, or categorized messages.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("messages", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_messages": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"messages": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_messages(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.list(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_messages(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
messages = [m.model_dump() for m in response.messages]
|
||||
|
||||
yield "messages", messages
|
||||
yield "count", (c if (c := response.count) is not None else len(messages))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetMessageBlock(Block):
|
||||
"""
|
||||
Retrieve a specific email message by ID from an AgentMail inbox.
|
||||
|
||||
Returns the full message including subject, body (text and HTML),
|
||||
sender, recipients, and attachments. Use extracted_text to get
|
||||
only the new reply content without quoted history.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to retrieve (e.g. '<abc123@agentmail.to>')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(description="Unique identifier of the message")
|
||||
thread_id: str = SchemaField(description="Thread this message belongs to")
|
||||
subject: str = SchemaField(description="Email subject line")
|
||||
text: str = SchemaField(
|
||||
description="Full plain text body (may include quoted reply history)"
|
||||
)
|
||||
extracted_text: str = SchemaField(
|
||||
description="Just the new reply content with quoted history stripped. Best for AI processing.",
|
||||
default="",
|
||||
)
|
||||
html: str = SchemaField(description="HTML body of the email", default="")
|
||||
result: dict = SchemaField(
|
||||
description="Complete message object with all fields including sender, recipients, attachments, labels"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2788bdfa-1527-4603-a5e4-a455c05c032f",
|
||||
description="Retrieve a specific email message by ID. Includes extracted_text for clean reply content without quoted history.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "test-msg"),
|
||||
("thread_id", "t1"),
|
||||
("subject", "Hi"),
|
||||
("text", "Hello"),
|
||||
("extracted_text", "Hello"),
|
||||
("html", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "test-msg",
|
||||
"thread_id": "t1",
|
||||
"subject": "Hi",
|
||||
"text": "Hello",
|
||||
"extracted_text": "Hello",
|
||||
"html": "",
|
||||
"model_dump": lambda self: {"message_id": "test-msg"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_message(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
message_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.get(
|
||||
inbox_id=inbox_id, message_id=message_id
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
msg = await self.get_message(
|
||||
credentials, input_data.inbox_id, input_data.message_id
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "subject", msg.subject or ""
|
||||
yield "text", msg.text or ""
|
||||
yield "extracted_text", msg.extracted_text or ""
|
||||
yield "html", msg.html or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailReplyToMessageBlock(Block):
|
||||
"""
|
||||
Reply to an existing email message, keeping the reply in the same thread.
|
||||
|
||||
The reply is automatically added to the same conversation thread as the
|
||||
original message. Use this for multi-turn agent conversations.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to send the reply from"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to reply to (e.g. '<abc123@agentmail.to>')"
|
||||
)
|
||||
text: str = SchemaField(description="Plain text body of the reply")
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the reply",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the reply message"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID the reply was added to")
|
||||
result: dict = SchemaField(
|
||||
description="Complete reply message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b9fe53fa-5026-4547-9570-b54ccb487229",
|
||||
description="Reply to an existing email in the same conversation thread. Use for multi-turn agent conversations.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"text": "Reply",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-reply-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"reply_to_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-reply-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-reply-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def reply_to_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.reply(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"text": input_data.text}
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
|
||||
reply = await self.reply_to_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = reply.model_dump()
|
||||
|
||||
yield "message_id", reply.message_id
|
||||
yield "thread_id", reply.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailForwardMessageBlock(Block):
|
||||
"""
|
||||
Forward an existing email message to one or more recipients.
|
||||
|
||||
Sends the original message content to different email addresses.
|
||||
Optionally prepend additional text or override the subject line.
|
||||
Max 50 combined recipients across to, cc, and bcc.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to forward from"
|
||||
)
|
||||
message_id: str = SchemaField(description="Message ID to forward")
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses to forward the message to (e.g. ['user@example.com'])"
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses (hidden from other recipients)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
subject: str = SchemaField(
|
||||
description="Override the subject line (defaults to 'Fwd: <original subject>')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text: str = SchemaField(
|
||||
description="Additional plain text to prepend before the forwarded content",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
html: str = SchemaField(
|
||||
description="Additional HTML to prepend before the forwarded content",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the forwarded message"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID of the forward")
|
||||
result: dict = SchemaField(
|
||||
description="Complete forwarded message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b70c7e33-5d66-4f8e-897f-ac73a7bfce82",
|
||||
description="Forward an email message to one or more recipients. Supports CC/BCC and optional extra text or subject override.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"to": ["user@example.com"],
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-fwd-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"forward_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-fwd-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-fwd-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def forward_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.forward(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
|
||||
if total > 50:
|
||||
raise ValueError(
|
||||
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
|
||||
)
|
||||
|
||||
params: dict = {"to": input_data.to}
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.subject:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
|
||||
fwd = await self.forward_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = fwd.model_dump()
|
||||
|
||||
yield "message_id", fwd.message_id
|
||||
yield "thread_id", fwd.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateMessageBlock(Block):
|
||||
"""
|
||||
Add or remove labels on an email message for state management.
|
||||
|
||||
Labels are string tags used to track message state (read/unread),
|
||||
categorize messages (billing, support), or tag campaigns (q4-outreach).
|
||||
Common pattern: add 'read' and remove 'unread' after processing a message.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(description="Message ID to update labels on")
|
||||
add_labels: list[str] = SchemaField(
|
||||
description="Labels to add (e.g. ['read', 'processed', 'high-priority'])",
|
||||
default_factory=list,
|
||||
)
|
||||
remove_labels: list[str] = SchemaField(
|
||||
description="Labels to remove (e.g. ['unread', 'pending'])",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(description="The updated message ID")
|
||||
result: dict = SchemaField(
|
||||
description="Complete updated message object with current labels"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="694ff816-4c89-4a5e-a552-8c31be187735",
|
||||
description="Add or remove labels on an email message. Use for read/unread tracking, campaign tagging, or state management.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"add_labels": ["read"],
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "test-msg"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "test-msg",
|
||||
"model_dump": lambda self: {"message_id": "test-msg"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.update(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
if not input_data.add_labels and not input_data.remove_labels:
|
||||
raise ValueError(
|
||||
"Must specify at least one label operation: add_labels or remove_labels"
|
||||
)
|
||||
|
||||
params: dict = {}
|
||||
if input_data.add_labels:
|
||||
params["add_labels"] = input_data.add_labels
|
||||
if input_data.remove_labels:
|
||||
params["remove_labels"] = input_data.remove_labels
|
||||
|
||||
msg = await self.update_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,651 +0,0 @@
|
||||
"""
|
||||
AgentMail Pod blocks — create, get, list, delete pods and list pod-scoped resources.
|
||||
|
||||
Pods provide multi-tenant isolation between your customers. Each pod acts as
|
||||
an isolated workspace containing its own inboxes, domains, threads, and drafts.
|
||||
Use pods when building SaaS platforms, agency tools, or AI agent fleets that
|
||||
serve multiple customers.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreatePodBlock(Block):
|
||||
"""
|
||||
Create a new pod for multi-tenant customer isolation.
|
||||
|
||||
Each pod acts as an isolated workspace for one customer or tenant.
|
||||
Use client_id to map pods to your internal tenant IDs for idempotent
|
||||
creation (safe to retry without creating duplicates).
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
client_id: str = SchemaField(
|
||||
description="Your internal tenant/customer ID for idempotent mapping. Lets you access the pod by your own ID instead of AgentMail's pod_id.",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pod_id: str = SchemaField(description="Unique identifier of the created pod")
|
||||
result: dict = SchemaField(description="Complete pod object with all metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a2db9784-2d17-4f8f-9d6b-0214e6f22101",
|
||||
description="Create a new pod for multi-tenant customer isolation. Use client_id to map to your internal tenant IDs.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("pod_id", "mock-pod-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_pod": lambda *a, **kw: type(
|
||||
"Pod",
|
||||
(),
|
||||
{
|
||||
"pod_id": "mock-pod-id",
|
||||
"model_dump": lambda self: {"pod_id": "mock-pod-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_pod(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.create(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.client_id:
|
||||
params["client_id"] = input_data.client_id
|
||||
|
||||
pod = await self.create_pod(credentials, **params)
|
||||
result = pod.model_dump()
|
||||
|
||||
yield "pod_id", pod.pod_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetPodBlock(Block):
|
||||
"""
|
||||
Retrieve details of an existing pod by its ID.
|
||||
|
||||
Returns the pod metadata including its client_id mapping and
|
||||
creation timestamp.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pod_id: str = SchemaField(description="Unique identifier of the pod")
|
||||
result: dict = SchemaField(description="Complete pod object with all metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="553361bc-bb1b-4322-9ad4-0c226200217e",
|
||||
description="Retrieve details of an existing pod including its client_id mapping and metadata.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("pod_id", "test-pod"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_pod": lambda *a, **kw: type(
|
||||
"Pod",
|
||||
(),
|
||||
{
|
||||
"pod_id": "test-pod",
|
||||
"model_dump": lambda self: {"pod_id": "test-pod"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_pod(credentials: APIKeyCredentials, pod_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.pods.get(pod_id=pod_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
pod = await self.get_pod(credentials, pod_id=input_data.pod_id)
|
||||
result = pod.model_dump()
|
||||
|
||||
yield "pod_id", pod.pod_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodsBlock(Block):
|
||||
"""
|
||||
List all pods in your AgentMail organization.
|
||||
|
||||
Returns a paginated list of all tenant pods with their metadata.
|
||||
Use this to see all customer workspaces at a glance.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of pods to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pods: list[dict] = SchemaField(
|
||||
description="List of pod objects with pod_id, client_id, creation time, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of pods returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9d3725ee-2968-431a-a816-857ab41e1420",
|
||||
description="List all tenant pods in your organization. See all customer workspaces at a glance.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("pods", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pods": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"pods": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pods(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pods(credentials, **params)
|
||||
pods = [p.model_dump() for p in response.pods]
|
||||
|
||||
yield "pods", pods
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeletePodBlock(Block):
|
||||
"""
|
||||
Permanently delete a pod. All inboxes and domains must be removed first.
|
||||
|
||||
You cannot delete a pod that still contains inboxes or domains.
|
||||
Delete all child resources first, then delete the pod.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(
|
||||
description="Pod ID to permanently delete (must have no inboxes or domains)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the pod was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f371f8cd-682d-4f5f-905c-529c74a8fb35",
|
||||
description="Permanently delete a pod. All inboxes and domains must be removed first.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_pod": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_pod(credentials: APIKeyCredentials, pod_id: str):
|
||||
client = _client(credentials)
|
||||
await client.pods.delete(pod_id=pod_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_pod(credentials, pod_id=input_data.pod_id)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodInboxesBlock(Block):
|
||||
"""
|
||||
List all inboxes within a specific pod (customer workspace).
|
||||
|
||||
Returns only the inboxes belonging to this pod, providing
|
||||
tenant-scoped visibility.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list inboxes from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of inboxes to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inboxes: list[dict] = SchemaField(
|
||||
description="List of inbox objects within this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of inboxes returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a8c17ce0-b7c1-4bc3-ae39-680e1952e5d0",
|
||||
description="List all inboxes within a pod. View email accounts scoped to a specific customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("inboxes", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_inboxes": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"inboxes": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_inboxes(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.inboxes.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pod_inboxes(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
inboxes = [i.model_dump() for i in response.inboxes]
|
||||
|
||||
yield "inboxes", inboxes
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodThreadsBlock(Block):
|
||||
"""
|
||||
List all conversation threads across all inboxes within a pod.
|
||||
|
||||
Returns threads from every inbox in the pod. Use for building
|
||||
per-customer dashboards showing all email activity, or for
|
||||
supervisor agents monitoring a customer's conversations.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list threads from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects from all inboxes in this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="80214f08-8b85-4533-a6b8-f8123bfcb410",
|
||||
description="List all conversation threads across all inboxes within a pod. View all email activity for a customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_threads(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.threads.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_pod_threads(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts across all inboxes within a pod.
|
||||
|
||||
Returns pending drafts from every inbox in the pod. Use for
|
||||
per-customer approval dashboards or monitoring scheduled sends.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list drafts from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects from all inboxes in this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="12fd7a3e-51ad-4b20-97c1-0391f207f517",
|
||||
description="List all drafts across all inboxes within a pod. View pending emails for a customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_drafts(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.drafts.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pod_drafts(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailCreatePodInboxBlock(Block):
|
||||
"""
|
||||
Create a new email inbox within a specific pod (customer workspace).
|
||||
|
||||
The inbox is automatically scoped to the pod and inherits its
|
||||
isolation guarantees. If username/domain are not provided,
|
||||
AgentMail auto-generates a unique address.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to create the inbox in")
|
||||
username: str = SchemaField(
|
||||
description="Local part of the email address (e.g. 'support'). Leave empty to auto-generate.",
|
||||
default="",
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
|
||||
default="",
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field (e.g. 'Customer Support')",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(
|
||||
description="Unique identifier of the created inbox"
|
||||
)
|
||||
email_address: str = SchemaField(description="Full email address of the inbox")
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c6862373-1ac6-402e-89e6-7db1fea882af",
|
||||
description="Create a new email inbox within a pod. The inbox is scoped to the customer workspace.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("inbox_id", "mock-inbox-id"),
|
||||
("email_address", "mock-inbox-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_pod_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "mock-inbox-id",
|
||||
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_pod_inbox(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.inboxes.create(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.username:
|
||||
params["username"] = input_data.username
|
||||
if input_data.domain:
|
||||
params["domain"] = input_data.domain
|
||||
if input_data.display_name:
|
||||
params["display_name"] = input_data.display_name
|
||||
|
||||
inbox = await self.create_pod_inbox(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,438 +0,0 @@
|
||||
"""
|
||||
AgentMail Thread blocks — list, get, and delete conversation threads.
|
||||
|
||||
A Thread groups related messages into a single conversation. Threads are
|
||||
created automatically when a new message is sent and grow as replies are added.
|
||||
Threads can be queried per-inbox or across the entire organization.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailListInboxThreadsBlock(Block):
|
||||
"""
|
||||
List all conversation threads within a specific AgentMail inbox.
|
||||
|
||||
Returns a paginated list of threads with optional label filtering.
|
||||
Use labels to find threads by campaign, status, or custom tags.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list threads from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels (e.g. ['q4-campaign', 'follow-up'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects with thread_id, subject, message count, labels, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="63dd9e2d-ef81-405c-b034-c031f0437334",
|
||||
description="List all conversation threads in an AgentMail inbox. Filter by labels for campaign tracking or status management.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_threads(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.list(inbox_id=inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_threads(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", (c if (c := response.count) is not None else len(threads))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetInboxThreadBlock(Block):
|
||||
"""
|
||||
Retrieve a single conversation thread from an AgentMail inbox.
|
||||
|
||||
Returns the thread with all its messages in chronological order.
|
||||
Use this to get the full conversation history for context when
|
||||
composing replies.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
thread_id: str = SchemaField(description="Unique identifier of the thread")
|
||||
messages: list[dict] = SchemaField(
|
||||
description="All messages in the thread, in chronological order"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete thread object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="42866290-1479-4153-83e7-550b703e9da2",
|
||||
description="Retrieve a conversation thread with all its messages. Use for getting full conversation context before replying.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[
|
||||
("thread_id", "test-thread"),
|
||||
("messages", []),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_thread": lambda *a, **kw: type(
|
||||
"Thread",
|
||||
(),
|
||||
{
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
"model_dump": lambda self: {
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_thread(credentials: APIKeyCredentials, inbox_id: str, thread_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.get(inbox_id=inbox_id, thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
thread = await self.get_thread(
|
||||
credentials, input_data.inbox_id, input_data.thread_id
|
||||
)
|
||||
messages = [m.model_dump() for m in thread.messages]
|
||||
result = thread.model_dump()
|
||||
result["messages"] = messages
|
||||
|
||||
yield "thread_id", thread.thread_id
|
||||
yield "messages", messages
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteInboxThreadBlock(Block):
|
||||
"""
|
||||
Permanently delete a conversation thread and all its messages from an inbox.
|
||||
|
||||
This removes the thread and every message within it. This action
|
||||
cannot be undone.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID to permanently delete")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the thread was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="18cd5f6f-4ff6-45da-8300-25a50ea7fb75",
|
||||
description="Permanently delete a conversation thread and all its messages. This action cannot be undone.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_thread": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_thread(
|
||||
credentials: APIKeyCredentials, inbox_id: str, thread_id: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.threads.delete(inbox_id=inbox_id, thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_thread(
|
||||
credentials, input_data.inbox_id, input_data.thread_id
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListOrgThreadsBlock(Block):
|
||||
"""
|
||||
List conversation threads across ALL inboxes in your organization.
|
||||
|
||||
Unlike per-inbox listing, this returns threads from every inbox.
|
||||
Ideal for building supervisor agents that monitor all conversations,
|
||||
analytics dashboards, or cross-agent routing workflows.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects from all inboxes in the organization"
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d7a0657b-58ab-48b2-898b-7bd94f44a708",
|
||||
description="List threads across ALL inboxes in your organization. Use for supervisor agents, dashboards, or cross-agent monitoring.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_org_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_org_threads(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.threads.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_org_threads(credentials, **params)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", (c if (c := response.count) is not None else len(threads))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetOrgThreadBlock(Block):
|
||||
"""
|
||||
Retrieve a single conversation thread by ID from anywhere in the organization.
|
||||
|
||||
Works without needing to know which inbox the thread belongs to.
|
||||
Returns the thread with all its messages in chronological order.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID to retrieve (works across all inboxes)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
thread_id: str = SchemaField(description="Unique identifier of the thread")
|
||||
messages: list[dict] = SchemaField(
|
||||
description="All messages in the thread, in chronological order"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete thread object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="39aaae31-3eb1-44c6-9e37-5a44a4529649",
|
||||
description="Retrieve a conversation thread by ID from anywhere in the organization, without needing the inbox ID.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[
|
||||
("thread_id", "test-thread"),
|
||||
("messages", []),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_org_thread": lambda *a, **kw: type(
|
||||
"Thread",
|
||||
(),
|
||||
{
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
"model_dump": lambda self: {
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_org_thread(credentials: APIKeyCredentials, thread_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.threads.get(thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
thread = await self.get_org_thread(credentials, input_data.thread_id)
|
||||
messages = [m.model_dump() for m in thread.messages]
|
||||
result = thread.model_dump()
|
||||
result["messages"] = messages
|
||||
|
||||
yield "thread_id", thread.thread_id
|
||||
yield "messages", messages
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -27,7 +27,6 @@ from backend.util.file import MediaFileType, store_media_file
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
NANO_BANANA_PRO = "google/nano-banana-pro"
|
||||
NANO_BANANA_2 = "google/nano-banana-2"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -78,7 +77,7 @@ class AIImageCustomizerBlock(Block):
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA_2,
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
@@ -104,7 +103,7 @@ class AIImageCustomizerBlock(Block):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana models from Gemini. "
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
@@ -112,7 +111,7 @@ class AIImageCustomizerBlock(Block):
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA_2,
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"images": [],
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"output_format": OutputFormat.JPG,
|
||||
|
||||
@@ -115,7 +115,6 @@ class ImageGenModel(str, Enum):
|
||||
RECRAFT = "Recraft v3"
|
||||
SD3_5 = "Stable Diffusion 3.5 Medium"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
NANO_BANANA_2 = "Nano Banana 2"
|
||||
|
||||
|
||||
class AIImageGeneratorBlock(Block):
|
||||
@@ -132,7 +131,7 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
model: ImageGenModel = SchemaField(
|
||||
description="The AI model to use for image generation",
|
||||
default=ImageGenModel.NANO_BANANA_2,
|
||||
default=ImageGenModel.SD3_5,
|
||||
title="Model",
|
||||
)
|
||||
size: ImageSize = SchemaField(
|
||||
@@ -166,7 +165,7 @@ class AIImageGeneratorBlock(Block):
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"prompt": "An octopus using a laptop in a snowy forest with 'AutoGPT' clearly visible on the screen",
|
||||
"model": ImageGenModel.NANO_BANANA_2,
|
||||
"model": ImageGenModel.RECRAFT,
|
||||
"size": ImageSize.SQUARE,
|
||||
"style": ImageStyle.REALISTIC,
|
||||
},
|
||||
@@ -180,9 +179,7 @@ class AIImageGeneratorBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
# Return a data URI directly so store_media_file doesn't need to download
|
||||
"_run_client": lambda *args, **kwargs: (
|
||||
"data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||
)
|
||||
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||
},
|
||||
)
|
||||
|
||||
@@ -283,24 +280,17 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
return output
|
||||
|
||||
elif input_data.model in (
|
||||
ImageGenModel.NANO_BANANA_PRO,
|
||||
ImageGenModel.NANO_BANANA_2,
|
||||
):
|
||||
# Use Nano Banana models (Google Gemini image variants)
|
||||
model_map = {
|
||||
ImageGenModel.NANO_BANANA_PRO: "google/nano-banana-pro",
|
||||
ImageGenModel.NANO_BANANA_2: "google/nano-banana-2",
|
||||
}
|
||||
elif input_data.model == ImageGenModel.NANO_BANANA_PRO:
|
||||
# Use Nano Banana Pro (Google Gemini 3 Pro Image)
|
||||
input_params = {
|
||||
"prompt": modified_prompt,
|
||||
"aspect_ratio": SIZE_TO_NANO_BANANA_RATIO[input_data.size],
|
||||
"resolution": "2K",
|
||||
"resolution": "2K", # Default to 2K for good quality/cost balance
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high",
|
||||
"safety_filter_level": "block_only_high", # Most permissive
|
||||
}
|
||||
output = await self._run_client(
|
||||
credentials, model_map[input_data.model], input_params
|
||||
credentials, "google/nano-banana-pro", input_params
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@@ -1,376 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Block ID shared between autopilot.py and copilot prompting.py.
|
||||
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
|
||||
|
||||
class ToolCallEntry(TypedDict):
|
||||
"""A single tool invocation record from an autopilot execution."""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
input: Any
|
||||
output: Any | None
|
||||
success: bool | None
|
||||
|
||||
|
||||
class TokenUsage(TypedDict):
|
||||
"""Aggregated token counts from the autopilot stream."""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class AutoPilotBlock(Block):
|
||||
"""Execute tasks using AutoGPT AutoPilot with full access to platform tools.
|
||||
|
||||
The autopilot can manage agents, access workspace files, fetch web content,
|
||||
run blocks, and more. This block enables sub-agent patterns (autopilot calling
|
||||
autopilot) and scheduled autopilot execution via the agent executor.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for the AutoPilot block."""
|
||||
|
||||
prompt: str = SchemaField(
|
||||
description=(
|
||||
"The task or instruction for the autopilot to execute. "
|
||||
"The autopilot has access to platform tools like agent management, "
|
||||
"workspace files, web fetch, block execution, and more."
|
||||
),
|
||||
placeholder="Find my agents and list them",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
system_context: str = SchemaField(
|
||||
description=(
|
||||
"Optional additional context prepended to the prompt. "
|
||||
"Use this to constrain autopilot behavior, provide domain "
|
||||
"context, or set output format requirements."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID to continue an existing autopilot conversation. "
|
||||
"Leave empty to start a new session. "
|
||||
"Use the session_id output from a previous run to continue."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
max_recursion_depth: int = SchemaField(
|
||||
description=(
|
||||
"Maximum nesting depth when the autopilot calls this block "
|
||||
"recursively (sub-agent pattern). Prevents infinite loops."
|
||||
),
|
||||
default=3,
|
||||
ge=1,
|
||||
le=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# timeout_seconds removed: the SDK manages its own heartbeat-based
|
||||
# timeouts internally; wrapping with asyncio.timeout corrupts the
|
||||
# SDK's internal stream (see service.py CRITICAL comment).
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for the AutoPilot block."""
|
||||
|
||||
response: str = SchemaField(
|
||||
description="The final text response from the autopilot."
|
||||
)
|
||||
tool_calls: list[ToolCallEntry] = SchemaField(
|
||||
description=(
|
||||
"List of tools called during execution. Each entry has "
|
||||
"tool_call_id, tool_name, input, output, and success fields."
|
||||
),
|
||||
)
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Current turn messages (user prompt + assistant reply) as JSON. "
|
||||
"It can be used for logging or analysis."
|
||||
),
|
||||
)
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID for this conversation. "
|
||||
"Pass this back to continue the conversation in a future run."
|
||||
),
|
||||
)
|
||||
token_usage: TokenUsage = SchemaField(
|
||||
description=(
|
||||
"Token usage statistics: prompt_tokens, "
|
||||
"completion_tokens, total_tokens."
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id=AUTOPILOT_BLOCK_ID,
|
||||
description=(
|
||||
"Execute tasks using AutoGPT AutoPilot with full access to "
|
||||
"platform tools (agent management, workspace files, web fetch, "
|
||||
"block execution, and more). Enables sub-agent patterns and "
|
||||
"scheduled autopilot execution."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.AGENT},
|
||||
input_schema=AutoPilotBlock.Input,
|
||||
output_schema=AutoPilotBlock.Output,
|
||||
test_input={
|
||||
"prompt": "List my agents",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
},
|
||||
test_output=[
|
||||
("response", "You have 2 agents: Agent A and Agent B."),
|
||||
("tool_calls", []),
|
||||
(
|
||||
"conversation_history",
|
||||
'[{"role": "user", "content": "List my agents"}]',
|
||||
),
|
||||
("session_id", "test-session-id"),
|
||||
(
|
||||
"token_usage",
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_session": lambda *args, **kwargs: "test-session-id",
|
||||
"execute_copilot": lambda *args, **kwargs: (
|
||||
"You have 2 agents: Agent A and Agent B.",
|
||||
[],
|
||||
'[{"role": "user", "content": "List my agents"}]',
|
||||
"test-session-id",
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def create_session(self, user_id: str) -> str:
|
||||
"""Create a new chat session and return its ID (mockable for tests)."""
|
||||
from backend.copilot.model import create_chat_session
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
return session.session_id
|
||||
|
||||
async def execute_copilot(
|
||||
self,
|
||||
prompt: str,
|
||||
system_context: str,
|
||||
session_id: str,
|
||||
max_recursion_depth: int,
|
||||
user_id: str,
|
||||
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
|
||||
"""Invoke the copilot and collect all stream results.
|
||||
|
||||
Delegates to :func:`collect_copilot_response` — the shared helper that
|
||||
consumes ``stream_chat_completion_sdk`` without wrapping it in an
|
||||
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
|
||||
|
||||
Args:
|
||||
prompt: The user task/instruction.
|
||||
system_context: Optional context prepended to the prompt.
|
||||
session_id: Chat session to use.
|
||||
max_recursion_depth: Maximum allowed recursion nesting.
|
||||
user_id: Authenticated user ID.
|
||||
|
||||
Returns:
|
||||
A tuple of (response_text, tool_calls, history_json, session_id, usage).
|
||||
"""
|
||||
from backend.copilot.sdk.collect import collect_copilot_response
|
||||
|
||||
tokens = _check_recursion(max_recursion_depth)
|
||||
try:
|
||||
effective_prompt = prompt
|
||||
if system_context:
|
||||
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
|
||||
|
||||
result = await collect_copilot_response(
|
||||
session_id=session_id,
|
||||
message=effective_prompt,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Build a lightweight conversation summary from streamed data.
|
||||
turn_messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": effective_prompt},
|
||||
]
|
||||
if result.tool_calls:
|
||||
turn_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": result.response_text,
|
||||
"tool_calls": result.tool_calls,
|
||||
}
|
||||
)
|
||||
else:
|
||||
turn_messages.append(
|
||||
{"role": "assistant", "content": result.response_text}
|
||||
)
|
||||
history_json = json.dumps(turn_messages, default=str)
|
||||
|
||||
tool_calls: list[ToolCallEntry] = [
|
||||
{
|
||||
"tool_call_id": tc["tool_call_id"],
|
||||
"tool_name": tc["tool_name"],
|
||||
"input": tc["input"],
|
||||
"output": tc["output"],
|
||||
"success": tc["success"],
|
||||
}
|
||||
for tc in result.tool_calls
|
||||
]
|
||||
|
||||
usage: TokenUsage = {
|
||||
"prompt_tokens": result.prompt_tokens,
|
||||
"completion_tokens": result.completion_tokens,
|
||||
"total_tokens": result.total_tokens,
|
||||
}
|
||||
|
||||
return (
|
||||
result.response_text,
|
||||
tool_calls,
|
||||
history_json,
|
||||
session_id,
|
||||
usage,
|
||||
)
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Validate inputs, invoke the autopilot, and yield structured outputs.
|
||||
|
||||
Yields session_id even on failure so callers can inspect/resume the session.
|
||||
"""
|
||||
if not input_data.prompt.strip():
|
||||
yield "error", "Prompt cannot be empty."
|
||||
return
|
||||
|
||||
if not execution_context.user_id:
|
||||
yield "error", "Cannot run autopilot without an authenticated user."
|
||||
return
|
||||
|
||||
if input_data.max_recursion_depth < 1:
|
||||
yield "error", "max_recursion_depth must be at least 1."
|
||||
return
|
||||
|
||||
# Create session eagerly so the user always gets the session_id,
|
||||
# even if the downstream stream fails (avoids orphaned sessions).
|
||||
sid = input_data.session_id
|
||||
if not sid:
|
||||
sid = await self.create_session(execution_context.user_id)
|
||||
|
||||
# NOTE: No asyncio.timeout() here — the SDK manages its own
|
||||
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout
|
||||
# would cancel the task mid-flight, corrupting the SDK's internal
|
||||
# anyio memory stream (see service.py CRITICAL comment).
|
||||
try:
|
||||
response, tool_calls, history, _, usage = await self.execute_copilot(
|
||||
prompt=input_data.prompt,
|
||||
system_context=input_data.system_context,
|
||||
session_id=sid,
|
||||
max_recursion_depth=input_data.max_recursion_depth,
|
||||
user_id=execution_context.user_id,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
yield "tool_calls", tool_calls
|
||||
yield "conversation_history", history
|
||||
yield "session_id", sid
|
||||
yield "token_usage", usage
|
||||
except asyncio.CancelledError:
|
||||
yield "session_id", sid
|
||||
yield "error", "AutoPilot execution was cancelled."
|
||||
raise
|
||||
except Exception as exc:
|
||||
yield "session_id", sid
|
||||
yield "error", str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers – placed after the block class for top-down readability.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Task-scoped recursion depth counter & chain-wide limit.
|
||||
# contextvars are scoped to the current asyncio task, so concurrent
|
||||
# graph executions each get independent counters.
|
||||
_autopilot_recursion_depth: contextvars.ContextVar[int] = contextvars.ContextVar(
|
||||
"_autopilot_recursion_depth", default=0
|
||||
)
|
||||
_autopilot_recursion_limit: contextvars.ContextVar[int | None] = contextvars.ContextVar(
|
||||
"_autopilot_recursion_limit", default=None
|
||||
)
|
||||
|
||||
|
||||
def _check_recursion(
|
||||
max_depth: int,
|
||||
) -> tuple[contextvars.Token[int], contextvars.Token[int | None]]:
|
||||
"""Check and increment recursion depth.
|
||||
|
||||
Returns ContextVar tokens that must be passed to ``_reset_recursion``
|
||||
when the caller exits to restore the previous depth.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the current depth already meets or exceeds the limit.
|
||||
"""
|
||||
current = _autopilot_recursion_depth.get()
|
||||
inherited = _autopilot_recursion_limit.get()
|
||||
limit = max_depth if inherited is None else min(inherited, max_depth)
|
||||
if current >= limit:
|
||||
raise RuntimeError(
|
||||
f"AutoPilot recursion depth limit reached ({limit}). "
|
||||
"The autopilot has called itself too many times."
|
||||
)
|
||||
return (
|
||||
_autopilot_recursion_depth.set(current + 1),
|
||||
_autopilot_recursion_limit.set(limit),
|
||||
)
|
||||
|
||||
|
||||
def _reset_recursion(
|
||||
tokens: tuple[contextvars.Token[int], contextvars.Token[int | None]],
|
||||
) -> None:
|
||||
"""Restore recursion depth and limit to their previous values."""
|
||||
_autopilot_recursion_depth.reset(tokens[0])
|
||||
_autopilot_recursion_limit.reset(tokens[1])
|
||||
@@ -472,7 +472,7 @@ class AddToListBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
entries_added = input_data.entries.copy()
|
||||
if input_data.entry is not None:
|
||||
if input_data.entry:
|
||||
entries_added.append(input_data.entry)
|
||||
|
||||
updated_list = input_data.list.copy()
|
||||
|
||||
@@ -21,7 +21,6 @@ from backend.data.model import (
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import resolve_and_check_blocked
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -97,11 +96,8 @@ class SendEmailBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Email sent successfully")],
|
||||
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
ALLOWED_SMTP_PORTS = {25, 465, 587, 2525}
|
||||
|
||||
@staticmethod
|
||||
def send_email(
|
||||
config: SMTPConfig,
|
||||
@@ -132,17 +128,6 @@ class SendEmailBlock(Block):
|
||||
self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# --- SSRF Protection ---
|
||||
smtp_port = input_data.config.smtp_port
|
||||
if smtp_port not in self.ALLOWED_SMTP_PORTS:
|
||||
yield "error", (
|
||||
f"SMTP port {smtp_port} is not allowed. "
|
||||
f"Allowed ports: {sorted(self.ALLOWED_SMTP_PORTS)}"
|
||||
)
|
||||
return
|
||||
|
||||
await resolve_and_check_blocked(input_data.config.smtp_server)
|
||||
|
||||
status = self.send_email(
|
||||
config=input_data.config,
|
||||
to_email=input_data.to_email,
|
||||
@@ -194,19 +179,7 @@ class SendEmailBlock(Block):
|
||||
"was rejected by the server. "
|
||||
"Please verify your account is authorized to send emails."
|
||||
)
|
||||
except smtplib.SMTPConnectError:
|
||||
yield "error", (
|
||||
f"Cannot connect to SMTP server '{input_data.config.smtp_server}' "
|
||||
f"on port {input_data.config.smtp_port}."
|
||||
)
|
||||
except smtplib.SMTPServerDisconnected:
|
||||
yield "error", (
|
||||
f"SMTP server '{input_data.config.smtp_server}' "
|
||||
"disconnected unexpectedly."
|
||||
)
|
||||
except smtplib.SMTPDataError as e:
|
||||
yield "error", f"Email data rejected by server: {str(e)}"
|
||||
except ValueError as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -34,29 +34,17 @@ TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
}
|
||||
|
||||
|
||||
class ImageEditorModel(str, Enum):
|
||||
FLUX_KONTEXT_PRO = "Flux Kontext Pro"
|
||||
FLUX_KONTEXT_MAX = "Flux Kontext Max"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
NANO_BANANA_2 = "Nano Banana 2"
|
||||
class FluxKontextModelName(str, Enum):
|
||||
PRO = "Flux Kontext Pro"
|
||||
MAX = "Flux Kontext Max"
|
||||
|
||||
@property
|
||||
def api_name(self) -> str:
|
||||
_map = {
|
||||
"FLUX_KONTEXT_PRO": "black-forest-labs/flux-kontext-pro",
|
||||
"FLUX_KONTEXT_MAX": "black-forest-labs/flux-kontext-max",
|
||||
"NANO_BANANA_PRO": "google/nano-banana-pro",
|
||||
"NANO_BANANA_2": "google/nano-banana-2",
|
||||
}
|
||||
return _map[self.name]
|
||||
|
||||
|
||||
# Keep old name as alias for backwards compatibility
|
||||
FluxKontextModelName = ImageEditorModel
|
||||
return f"black-forest-labs/flux-kontext-{self.name.lower()}"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -81,7 +69,7 @@ class AIImageEditorBlock(Block):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Flux Kontext and Nano Banana models",
|
||||
description="Replicate API key with permissions for Flux Kontext models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="Text instruction describing the desired edit",
|
||||
@@ -99,14 +87,14 @@ class AIImageEditorBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
seed: Optional[int] = SchemaField(
|
||||
description="Random seed. Set for reproducible generation (Flux Kontext only; ignored by Nano Banana models)",
|
||||
description="Random seed. Set for reproducible generation",
|
||||
default=None,
|
||||
title="Seed",
|
||||
advanced=True,
|
||||
)
|
||||
model: ImageEditorModel = SchemaField(
|
||||
model: FluxKontextModelName = SchemaField(
|
||||
description="Model variant to use",
|
||||
default=ImageEditorModel.NANO_BANANA_2,
|
||||
default=FluxKontextModelName.PRO,
|
||||
title="Model",
|
||||
)
|
||||
|
||||
@@ -119,7 +107,7 @@ class AIImageEditorBlock(Block):
|
||||
super().__init__(
|
||||
id="3fd9c73d-4370-4925-a1ff-1b86b99fabfa",
|
||||
description=(
|
||||
"Edit images using Flux Kontext or Google Nano Banana models. Provide a prompt "
|
||||
"Edit images using BlackForest Labs' Flux Kontext models. Provide a prompt "
|
||||
"and optional reference image to generate a modified image."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
@@ -130,7 +118,7 @@ class AIImageEditorBlock(Block):
|
||||
"input_image": "data:image/png;base64,MQ==",
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"seed": None,
|
||||
"model": ImageEditorModel.NANO_BANANA_2,
|
||||
"model": FluxKontextModelName.PRO,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
@@ -139,9 +127,7 @@ class AIImageEditorBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"run_model": lambda *args, **kwargs: (
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
),
|
||||
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
@@ -156,7 +142,7 @@ class AIImageEditorBlock(Block):
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model=input_data.model,
|
||||
model_name=input_data.model.api_name,
|
||||
prompt=input_data.prompt,
|
||||
input_image_b64=(
|
||||
await store_media_file(
|
||||
@@ -183,7 +169,7 @@ class AIImageEditorBlock(Block):
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model: ImageEditorModel,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
input_image_b64: Optional[str],
|
||||
aspect_ratio: str,
|
||||
@@ -192,29 +178,12 @@ class AIImageEditorBlock(Block):
|
||||
graph_exec_id: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
model_name = model.api_name
|
||||
|
||||
is_nano_banana = model in (
|
||||
ImageEditorModel.NANO_BANANA_PRO,
|
||||
ImageEditorModel.NANO_BANANA_2,
|
||||
)
|
||||
if is_nano_banana:
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high",
|
||||
}
|
||||
# NB API expects "image_input" as a list, unlike Flux's single "input_image"
|
||||
if input_image_b64:
|
||||
input_params["image_input"] = [input_image_b64]
|
||||
else:
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
try:
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
def github_repo_path(repo_url: str) -> str:
|
||||
"""Extract 'owner/repo' from a GitHub repository URL."""
|
||||
return repo_url.replace("https://github.com/", "")
|
||||
@@ -1,408 +0,0 @@
|
||||
import asyncio
|
||||
from enum import StrEnum
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import parse_data_uri, resolve_media_content
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListCommitsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to list commits from",
|
||||
default="main",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of commits to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class CommitItem(TypedDict):
|
||||
sha: str
|
||||
message: str
|
||||
author: str
|
||||
date: str
|
||||
url: str
|
||||
|
||||
commit: CommitItem = SchemaField(
|
||||
title="Commit", description="A commit with its details"
|
||||
)
|
||||
commits: list[CommitItem] = SchemaField(
|
||||
description="List of commits with their details"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing commits failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8b13f579-d8b6-4dc2-a140-f770428805de",
|
||||
description="This block lists commits on a branch in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListCommitsBlock.Input,
|
||||
output_schema=GithubListCommitsBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"commits",
|
||||
[
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"commit",
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_commits": lambda *args, **kwargs: [
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_commits(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
per_page: int,
|
||||
page: int,
|
||||
) -> list[Output.CommitItem]:
|
||||
api = get_api(credentials)
|
||||
commits_url = repo_url + "/commits"
|
||||
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
|
||||
response = await api.get(commits_url, params=params)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
return [
|
||||
GithubListCommitsBlock.Output.CommitItem(
|
||||
sha=c["sha"],
|
||||
message=c["commit"]["message"],
|
||||
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
|
||||
date=(c["commit"].get("author") or {}).get("date", ""),
|
||||
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
|
||||
)
|
||||
for c in data
|
||||
]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
commits = await self.list_commits(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "commits", commits
|
||||
for commit in commits:
|
||||
yield "commit", commit
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class FileOperation(StrEnum):
|
||||
"""File operations for GithubMultiFileCommitBlock.
|
||||
|
||||
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
|
||||
between creation and update — the blob is placed at the given path regardless
|
||||
of whether a file already exists there).
|
||||
|
||||
DELETE removes a file from the tree.
|
||||
"""
|
||||
|
||||
UPSERT = "upsert"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
# MediaFileType is a str NewType — no runtime breakage for existing callers.
|
||||
content: MediaFileType
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
class GithubMultiFileCommitBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to commit to",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Commit message",
|
||||
placeholder="Add new feature",
|
||||
)
|
||||
files: list[FileOperationInput] = SchemaField(
|
||||
description=(
|
||||
"List of file operations. Each item has: "
|
||||
"'path' (file path), 'content' (file content, ignored for delete), "
|
||||
"'operation' (upsert/delete)"
|
||||
),
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the new commit")
|
||||
url: str = SchemaField(description="URL of the new commit")
|
||||
error: str = SchemaField(description="Error message if the commit failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="389eee51-a95e-4230-9bed-92167a327802",
|
||||
description=(
|
||||
"This block creates a single commit with multiple file "
|
||||
"upsert/delete operations using the Git Trees API."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMultiFileCommitBlock.Input,
|
||||
output_schema=GithubMultiFileCommitBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "Add files",
|
||||
"files": [
|
||||
{
|
||||
"path": "src/new.py",
|
||||
"content": "print('hello')",
|
||||
"operation": "upsert",
|
||||
},
|
||||
{
|
||||
"path": "src/old.py",
|
||||
"content": "",
|
||||
"operation": "delete",
|
||||
},
|
||||
],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "newcommitsha"),
|
||||
("url", "https://github.com/owner/repo/commit/newcommitsha"),
|
||||
],
|
||||
test_mock={
|
||||
"multi_file_commit": lambda *args, **kwargs: (
|
||||
"newcommitsha",
|
||||
"https://github.com/owner/repo/commit/newcommitsha",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def multi_file_commit(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
files: list[FileOperationInput],
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
safe_branch = quote(branch, safe="")
|
||||
|
||||
# 1. Get the latest commit SHA for the branch
|
||||
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
|
||||
response = await api.get(ref_url)
|
||||
ref_data = response.json()
|
||||
latest_commit_sha = ref_data["object"]["sha"]
|
||||
|
||||
# 2. Get the tree SHA of the latest commit
|
||||
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
|
||||
response = await api.get(commit_url)
|
||||
commit_data = response.json()
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": encoding},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
tree_entries: list[dict] = []
|
||||
upsert_files = []
|
||||
for file_op in files:
|
||||
path = file_op["path"]
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
|
||||
if operation == FileOperation.DELETE:
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": None, # null SHA = delete
|
||||
}
|
||||
)
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently. Data URIs (from store_media_file)
|
||||
# are sent as base64 blobs to preserve binary content.
|
||||
if upsert_files:
|
||||
|
||||
async def _make_blob(content: str) -> str:
|
||||
parsed = parse_data_uri(content)
|
||||
if parsed is not None:
|
||||
_, b64_payload = parsed
|
||||
return await _create_blob(b64_payload, encoding="base64")
|
||||
return await _create_blob(content)
|
||||
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_make_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": blob_sha,
|
||||
}
|
||||
)
|
||||
|
||||
# 4. Create a new tree
|
||||
tree_url = repo_url + "/git/trees"
|
||||
tree_response = await api.post(
|
||||
tree_url,
|
||||
json={"base_tree": base_tree_sha, "tree": tree_entries},
|
||||
)
|
||||
new_tree_sha = tree_response.json()["sha"]
|
||||
|
||||
# 5. Create a new commit
|
||||
new_commit_url = repo_url + "/git/commits"
|
||||
commit_response = await api.post(
|
||||
new_commit_url,
|
||||
json={
|
||||
"message": commit_message,
|
||||
"tree": new_tree_sha,
|
||||
"parents": [latest_commit_sha],
|
||||
},
|
||||
)
|
||||
new_commit_sha = commit_response.json()["sha"]
|
||||
|
||||
# 6. Update the branch reference
|
||||
try:
|
||||
await api.patch(
|
||||
ref_url,
|
||||
json={"sha": new_commit_sha},
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Commit {new_commit_sha} was created but failed to update "
|
||||
f"ref heads/{branch}: {e}. "
|
||||
f"You can recover by manually updating the branch to {new_commit_sha}."
|
||||
) from e
|
||||
|
||||
repo_path = github_repo_path(repo_url)
|
||||
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
|
||||
return new_commit_sha, commit_web_url
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Resolve media references (workspace://, data:, URLs) to data
|
||||
# URIs so _make_blob can send binary content correctly.
|
||||
resolved_files: list[FileOperationInput] = []
|
||||
for file_op in input_data.files:
|
||||
content = file_op.get("content", "")
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
if operation != FileOperation.DELETE:
|
||||
content = await resolve_media_content(
|
||||
MediaFileType(content),
|
||||
execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
resolved_files.append(
|
||||
FileOperationInput(
|
||||
path=file_op["path"],
|
||||
content=MediaFileType(content),
|
||||
operation=operation,
|
||||
)
|
||||
)
|
||||
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
resolved_files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -21,8 +20,6 @@ from ._auth import (
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
MergeMethod = Literal["merge", "squash", "rebase"]
|
||||
|
||||
|
||||
class GithubListPullRequestsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -561,109 +558,12 @@ class GithubListPRReviewersBlock(Block):
|
||||
yield "reviewer", reviewer
|
||||
|
||||
|
||||
class GithubMergePullRequestBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
pr_url: str = SchemaField(
|
||||
description="URL of the GitHub pull request",
|
||||
placeholder="https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
merge_method: MergeMethod = SchemaField(
|
||||
description="Merge method to use: merge, squash, or rebase",
|
||||
default="merge",
|
||||
)
|
||||
commit_title: str = SchemaField(
|
||||
description="Title for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the merge commit")
|
||||
merged: bool = SchemaField(description="Whether the PR was merged")
|
||||
message: str = SchemaField(description="Merge status message")
|
||||
error: str = SchemaField(description="Error message if the merge failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
|
||||
description="This block merges a pull request using merge, squash, or rebase.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMergePullRequestBlock.Input,
|
||||
output_schema=GithubMergePullRequestBlock.Output,
|
||||
test_input={
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "abc123"),
|
||||
("merged", True),
|
||||
("message", "Pull Request successfully merged"),
|
||||
],
|
||||
test_mock={
|
||||
"merge_pr": lambda *args, **kwargs: (
|
||||
"abc123",
|
||||
True,
|
||||
"Pull Request successfully merged",
|
||||
)
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def merge_pr(
|
||||
credentials: GithubCredentials,
|
||||
pr_url: str,
|
||||
merge_method: MergeMethod,
|
||||
commit_title: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, bool, str]:
|
||||
api = get_api(credentials)
|
||||
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
|
||||
data: dict[str, str] = {"merge_method": merge_method}
|
||||
if commit_title:
|
||||
data["commit_title"] = commit_title
|
||||
if commit_message:
|
||||
data["commit_message"] = commit_message
|
||||
response = await api.put(merge_url, json=data)
|
||||
result = response.json()
|
||||
return result["sha"], result["merged"], result["message"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sha, merged, message = await self.merge_pr(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
input_data.merge_method,
|
||||
input_data.commit_title,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "merged", merged
|
||||
yield "message", message
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
def prepare_pr_api_url(pr_url: str, path: str) -> str:
|
||||
# Pattern to capture the base repository URL and the pull request number
|
||||
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
match = re.match(pattern, pr_url)
|
||||
if not match:
|
||||
return pr_url
|
||||
|
||||
scheme, base_url, pr_number = match.groups()
|
||||
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
|
||||
base_url, pr_number = match.groups()
|
||||
return f"{base_url}/pulls/{pr_number}/{path}"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import base64
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -17,7 +19,6 @@ from ._auth import (
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListTagsBlock(Block):
|
||||
@@ -88,7 +89,7 @@ class GithubListTagsBlock(Block):
|
||||
tags_url = repo_url + "/tags"
|
||||
response = await api.get(tags_url)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
tags: list[GithubListTagsBlock.Output.TagItem] = [
|
||||
{
|
||||
"name": tag["name"],
|
||||
@@ -114,6 +115,101 @@ class GithubListTagsBlock(Block):
|
||||
yield "tag", tag
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(branches_url)
|
||||
data = response.json()
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
|
||||
|
||||
class GithubListDiscussionsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -187,7 +283,7 @@ class GithubListDiscussionsBlock(Block):
|
||||
) -> list[Output.DiscussionItem]:
|
||||
api = get_api(credentials)
|
||||
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
|
||||
repo_path = github_repo_path(repo_url)
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
owner, repo = repo_path.split("/")
|
||||
query = """
|
||||
query($owner: String!, $repo: String!, $num: Int!) {
|
||||
@@ -320,6 +416,564 @@ class GithubListReleasesBlock(Block):
|
||||
yield "release", release
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to master)",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{branch}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -449,7 +1103,7 @@ class GithubListStargazersBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
|
||||
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
|
||||
description="This block lists all users who have starred a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListStargazersBlock.Input,
|
||||
@@ -518,230 +1172,3 @@ class GithubListStargazersBlock(Block):
|
||||
yield "stargazers", stargazers
|
||||
for stargazer in stargazers:
|
||||
yield "stargazer", stargazer
|
||||
|
||||
|
||||
class GithubGetRepositoryInfoBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
name: str = SchemaField(description="Repository name")
|
||||
full_name: str = SchemaField(description="Full repository name (owner/repo)")
|
||||
description: str = SchemaField(description="Repository description")
|
||||
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
|
||||
private: bool = SchemaField(description="Whether the repository is private")
|
||||
html_url: str = SchemaField(description="Web URL of the repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL")
|
||||
stars: int = SchemaField(description="Number of stars")
|
||||
forks: int = SchemaField(description="Number of forks")
|
||||
open_issues: int = SchemaField(description="Number of open issues")
|
||||
error: str = SchemaField(
|
||||
description="Error message if fetching repo info failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59d4f241-968a-4040-95da-348ac5c5ce27",
|
||||
description="This block retrieves metadata about a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryInfoBlock.Input,
|
||||
output_schema=GithubGetRepositoryInfoBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("name", "repo"),
|
||||
("full_name", "owner/repo"),
|
||||
("description", "A test repo"),
|
||||
("default_branch", "main"),
|
||||
("private", False),
|
||||
("html_url", "https://github.com/owner/repo"),
|
||||
("clone_url", "https://github.com/owner/repo.git"),
|
||||
("stars", 42),
|
||||
("forks", 5),
|
||||
("open_issues", 3),
|
||||
],
|
||||
test_mock={
|
||||
"get_repo_info": lambda *args, **kwargs: {
|
||||
"name": "repo",
|
||||
"full_name": "owner/repo",
|
||||
"description": "A test repo",
|
||||
"default_branch": "main",
|
||||
"private": False,
|
||||
"html_url": "https://github.com/owner/repo",
|
||||
"clone_url": "https://github.com/owner/repo.git",
|
||||
"stargazers_count": 42,
|
||||
"forks_count": 5,
|
||||
"open_issues_count": 3,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
|
||||
api = get_api(credentials)
|
||||
response = await api.get(repo_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_repo_info(credentials, input_data.repo_url)
|
||||
yield "name", data["name"]
|
||||
yield "full_name", data["full_name"]
|
||||
yield "description", data.get("description", "") or ""
|
||||
yield "default_branch", data["default_branch"]
|
||||
yield "private", data["private"]
|
||||
yield "html_url", data["html_url"]
|
||||
yield "clone_url", data["clone_url"]
|
||||
yield "stars", data["stargazers_count"]
|
||||
yield "forks", data["forks_count"]
|
||||
yield "open_issues", data["open_issues_count"]
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubForkRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to fork",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
organization: str = SchemaField(
|
||||
description="Organization to fork into (leave empty to fork to your account)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the forked repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL of the fork")
|
||||
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
|
||||
error: str = SchemaField(description="Error message if the fork failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
|
||||
description="This block forks a GitHub repository to your account or an organization.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubForkRepositoryBlock.Input,
|
||||
output_schema=GithubForkRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"organization": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/myuser/repo"),
|
||||
("clone_url", "https://github.com/myuser/repo.git"),
|
||||
("full_name", "myuser/repo"),
|
||||
],
|
||||
test_mock={
|
||||
"fork_repo": lambda *args, **kwargs: (
|
||||
"https://github.com/myuser/repo",
|
||||
"https://github.com/myuser/repo.git",
|
||||
"myuser/repo",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def fork_repo(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
organization: str,
|
||||
) -> tuple[str, str, str]:
|
||||
api = get_api(credentials)
|
||||
forks_url = repo_url + "/forks"
|
||||
data: dict[str, str] = {}
|
||||
if organization:
|
||||
data["organization"] = organization
|
||||
response = await api.post(forks_url, json=data)
|
||||
result = response.json()
|
||||
return result["html_url"], result["clone_url"], result["full_name"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, clone_url, full_name = await self.fork_repo(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.organization,
|
||||
)
|
||||
yield "url", url
|
||||
yield "clone_url", clone_url
|
||||
yield "full_name", full_name
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubStarRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to star",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the star operation")
|
||||
error: str = SchemaField(description="Error message if starring failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bd700764-53e3-44dd-a969-d1854088458f",
|
||||
description="This block stars a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubStarRepositoryBlock.Input,
|
||||
output_schema=GithubStarRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Repository starred successfully")],
|
||||
test_mock={
|
||||
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
repo_path = github_repo_path(repo_url)
|
||||
owner, repo = repo_path.split("/")
|
||||
await api.put(
|
||||
f"https://api.github.com/user/starred/{owner}/{repo}",
|
||||
headers={"Content-Length": "0"},
|
||||
)
|
||||
return "Repository starred successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.star_repo(credentials, input_data.repo_url)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -1,452 +0,0 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of branches to return per page (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing branches failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(
|
||||
branches_url, params={"per_page": str(per_page), "page": str(page)}
|
||||
)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCompareBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
base: str = SchemaField(
|
||||
description="Base branch or commit SHA",
|
||||
placeholder="main",
|
||||
)
|
||||
head: str = SchemaField(
|
||||
description="Head branch or commit SHA to compare against base",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class FileChange(TypedDict):
|
||||
filename: str
|
||||
status: str
|
||||
additions: int
|
||||
deletions: int
|
||||
patch: str
|
||||
|
||||
status: str = SchemaField(
|
||||
description="Comparison status: ahead, behind, diverged, or identical"
|
||||
)
|
||||
ahead_by: int = SchemaField(
|
||||
description="Number of commits head is ahead of base"
|
||||
)
|
||||
behind_by: int = SchemaField(
|
||||
description="Number of commits head is behind base"
|
||||
)
|
||||
total_commits: int = SchemaField(
|
||||
description="Total number of commits in the comparison"
|
||||
)
|
||||
diff: str = SchemaField(description="Unified diff of all file changes")
|
||||
file: FileChange = SchemaField(
|
||||
title="Changed File", description="A changed file with its diff"
|
||||
)
|
||||
files: list[FileChange] = SchemaField(
|
||||
description="List of changed files with their diffs"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if comparison failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2e4faa8c-6086-4546-ba77-172d1d560186",
|
||||
description="This block compares two branches or commits in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCompareBranchesBlock.Input,
|
||||
output_schema=GithubCompareBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"base": "main",
|
||||
"head": "feature",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("status", "ahead"),
|
||||
("ahead_by", 2),
|
||||
("behind_by", 0),
|
||||
("total_commits", 2),
|
||||
("diff", "+++ b/file.py\n+new line"),
|
||||
(
|
||||
"files",
|
||||
[
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"compare_branches": lambda *args, **kwargs: {
|
||||
"status": "ahead",
|
||||
"ahead_by": 2,
|
||||
"behind_by": 0,
|
||||
"total_commits": 2,
|
||||
"files": [
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def compare_branches(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
base: str,
|
||||
head: str,
|
||||
) -> dict:
|
||||
api = get_api(credentials)
|
||||
safe_base = quote(base, safe="")
|
||||
safe_head = quote(head, safe="")
|
||||
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
|
||||
response = await api.get(compare_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.compare_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.base,
|
||||
input_data.head,
|
||||
)
|
||||
yield "status", data["status"]
|
||||
yield "ahead_by", data["ahead_by"]
|
||||
yield "behind_by", data["behind_by"]
|
||||
yield "total_commits", data["total_commits"]
|
||||
|
||||
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
|
||||
GithubCompareBranchesBlock.Output.FileChange(
|
||||
filename=f["filename"],
|
||||
status=f["status"],
|
||||
additions=f["additions"],
|
||||
deletions=f["deletions"],
|
||||
patch=f.get("patch", ""),
|
||||
)
|
||||
for f in data.get("files", [])
|
||||
]
|
||||
|
||||
# Build unified diff
|
||||
diff_parts = []
|
||||
for f in data.get("files", []):
|
||||
patch = f.get("patch", "")
|
||||
if patch:
|
||||
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
|
||||
yield "diff", "\n".join(diff_parts)
|
||||
|
||||
yield "files", files
|
||||
for file in files:
|
||||
yield "file", file
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,720 +0,0 @@
|
||||
import base64
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
error: str = SchemaField(description="Error message if reading the file failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to main)",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubSearchCodeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
query: str = SchemaField(
|
||||
description="Search query (GitHub code search syntax)",
|
||||
placeholder="className language:python",
|
||||
)
|
||||
repo: str = SchemaField(
|
||||
description="Restrict search to a repository (owner/repo format, optional)",
|
||||
default="",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of results to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class SearchResult(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
repository: str
|
||||
url: str
|
||||
score: float
|
||||
|
||||
result: SearchResult = SchemaField(
|
||||
title="Result", description="A code search result"
|
||||
)
|
||||
results: list[SearchResult] = SchemaField(
|
||||
description="List of code search results"
|
||||
)
|
||||
total_count: int = SchemaField(description="Total number of matching results")
|
||||
error: str = SchemaField(description="Error message if search failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
|
||||
description="This block searches for code in GitHub repositories.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubSearchCodeBlock.Input,
|
||||
output_schema=GithubSearchCodeBlock.Output,
|
||||
test_input={
|
||||
"query": "addClass",
|
||||
"repo": "owner/repo",
|
||||
"per_page": 30,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("total_count", 1),
|
||||
(
|
||||
"results",
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"result",
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"search_code": lambda *args, **kwargs: (
|
||||
1,
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def search_code(
|
||||
credentials: GithubCredentials,
|
||||
query: str,
|
||||
repo: str,
|
||||
per_page: int,
|
||||
) -> tuple[int, list[Output.SearchResult]]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
full_query = f"{query} repo:{repo}" if repo else query
|
||||
params = {"q": full_query, "per_page": str(per_page)}
|
||||
response = await api.get("https://api.github.com/search/code", params=params)
|
||||
data = response.json()
|
||||
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
|
||||
GithubSearchCodeBlock.Output.SearchResult(
|
||||
name=item["name"],
|
||||
path=item["path"],
|
||||
repository=item["repository"]["full_name"],
|
||||
url=item["html_url"],
|
||||
score=item["score"],
|
||||
)
|
||||
for item in data["items"]
|
||||
]
|
||||
return data["total_count"], results
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total_count, results = await self.search_code(
|
||||
credentials,
|
||||
input_data.query,
|
||||
input_data.repo,
|
||||
input_data.per_page,
|
||||
)
|
||||
yield "total_count", total_count
|
||||
yield "results", results
|
||||
for result in results:
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubGetRepositoryTreeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to get the tree from",
|
||||
default="main",
|
||||
)
|
||||
recursive: bool = SchemaField(
|
||||
description="Whether to recursively list the entire tree",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class TreeEntry(TypedDict):
|
||||
path: str
|
||||
type: str
|
||||
size: int
|
||||
sha: str
|
||||
|
||||
entry: TreeEntry = SchemaField(
|
||||
title="Tree Entry", description="A file or directory in the tree"
|
||||
)
|
||||
entries: list[TreeEntry] = SchemaField(
|
||||
description="List of all files and directories in the tree"
|
||||
)
|
||||
truncated: bool = SchemaField(
|
||||
description="Whether the tree was truncated due to size"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if getting tree failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
|
||||
description="This block lists the entire file tree of a GitHub repository recursively.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryTreeBlock.Input,
|
||||
output_schema=GithubGetRepositoryTreeBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"recursive": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("truncated", False),
|
||||
(
|
||||
"entries",
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"entry",
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_tree": lambda *args, **kwargs: (
|
||||
False,
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_tree(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
recursive: bool,
|
||||
) -> tuple[bool, list[Output.TreeEntry]]:
|
||||
api = get_api(credentials)
|
||||
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
|
||||
params = {"recursive": "1"} if recursive else {}
|
||||
response = await api.get(tree_url, params=params)
|
||||
data = response.json()
|
||||
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
|
||||
GithubGetRepositoryTreeBlock.Output.TreeEntry(
|
||||
path=item["path"],
|
||||
type=item["type"],
|
||||
size=item.get("size", 0),
|
||||
sha=item["sha"],
|
||||
)
|
||||
for item in data["tree"]
|
||||
]
|
||||
return data.get("truncated", False), entries
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
truncated, entries = await self.get_tree(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.recursive,
|
||||
)
|
||||
yield "truncated", truncated
|
||||
yield "entries", entries
|
||||
for entry in entries:
|
||||
yield "entry", entry
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,125 +0,0 @@
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
|
||||
from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
|
||||
|
||||
class TestPreparePrApiUrl:
|
||||
def test_https_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/42/merge"
|
||||
|
||||
def test_http_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
|
||||
assert result == "http://github.com/owner/repo/pulls/1/files"
|
||||
|
||||
def test_no_scheme_defaults_to_https(self):
|
||||
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/5/merge"
|
||||
|
||||
def test_reviewers_path(self):
|
||||
result = prepare_pr_api_url(
|
||||
"https://github.com/owner/repo/pull/99", "requested_reviewers"
|
||||
)
|
||||
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
|
||||
|
||||
def test_invalid_url_returned_as_is(self):
|
||||
url = "https://example.com/not-a-pr"
|
||||
assert prepare_pr_api_url(url, "merge") == url
|
||||
|
||||
def test_empty_string(self):
|
||||
assert prepare_pr_api_url("", "merge") == ""
|
||||
|
||||
|
||||
# ── Error-path block tests ──
|
||||
# When a block's run() yields ("error", msg), _execute() converts it to a
|
||||
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
|
||||
# which returns early on empty test_output).
|
||||
|
||||
|
||||
def _mock_block(block, mocks: dict):
|
||||
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
|
||||
for name, mock_fn in mocks.items():
|
||||
original = getattr(block, name)
|
||||
if inspect.iscoroutinefunction(original):
|
||||
|
||||
async def async_mock(*args, _fn=mock_fn, **kwargs):
|
||||
return _fn(*args, **kwargs)
|
||||
|
||||
setattr(block, name, async_mock)
|
||||
else:
|
||||
setattr(block, name, mock_fn)
|
||||
|
||||
|
||||
def _raise(exc: Exception):
|
||||
"""Helper that returns a callable which raises the given exception."""
|
||||
|
||||
def _raiser(*args, **kwargs):
|
||||
raise exc
|
||||
|
||||
return _raiser
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pr_error_path():
|
||||
block = GithubMergePullRequestBlock()
|
||||
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
|
||||
input_data = {
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_file_commit_error_path():
|
||||
block = GithubMultiFileCommitBlock()
|
||||
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "test",
|
||||
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(
|
||||
input_data,
|
||||
credentials=TEST_CREDENTIALS,
|
||||
execution_context=ExecutionContext(),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# ── FileOperation enum tests ──
|
||||
|
||||
|
||||
class TestFileOperation:
|
||||
def test_upsert_value(self):
|
||||
assert FileOperation.UPSERT == "upsert"
|
||||
|
||||
def test_delete_value(self):
|
||||
assert FileOperation.DELETE == "delete"
|
||||
|
||||
def test_invalid_value_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("create")
|
||||
|
||||
def test_invalid_value_raises_typo(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("upser")
|
||||
@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
|
||||
h.ignore_links = False
|
||||
h.ignore_images = True
|
||||
return h.handle(html_content)
|
||||
except Exception:
|
||||
# Keep extraction resilient if html2text is unavailable or fails.
|
||||
except ImportError:
|
||||
# Fallback: return raw HTML if html2text is not available
|
||||
return html_content
|
||||
|
||||
# Handle content stored as attachment
|
||||
|
||||
@@ -211,7 +211,7 @@ class AgentOutputBlock(Block):
|
||||
if input_data.format:
|
||||
try:
|
||||
formatter = TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", await formatter.format_string(
|
||||
yield "output", formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -33,13 +33,6 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.openai_responses import (
|
||||
convert_tools_to_responses_format,
|
||||
extract_responses_content,
|
||||
extract_responses_reasoning,
|
||||
extract_responses_tool_calls,
|
||||
extract_responses_usage,
|
||||
)
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.request import validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
@@ -118,6 +111,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
GPT4O_MINI = "gpt-4o-mini"
|
||||
GPT4O = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
@@ -146,31 +140,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
|
||||
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
|
||||
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
|
||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
|
||||
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
|
||||
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
|
||||
CODESTRAL = "mistralai/codestral-2508"
|
||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
|
||||
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
|
||||
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
|
||||
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
|
||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
|
||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||
@@ -178,11 +160,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||
MICROSOFT_PHI_4 = "microsoft/phi-4"
|
||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||
GROK_3 = "x-ai/grok-3"
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
@@ -283,6 +263,9 @@ MODEL_METADATA = {
|
||||
LlmModel.GPT4_TURBO: ModelMetadata(
|
||||
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata(
|
||||
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
|
||||
), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
|
||||
@@ -357,41 +340,17 @@ MODEL_METADATA = {
|
||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
||||
),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
1050000,
|
||||
8192,
|
||||
"Gemini 2.5 Pro Preview 03.25",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 2.5 Pro",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Pro Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3 Flash Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
||||
@@ -399,15 +358,6 @@ MODEL_METADATA = {
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Flash Lite Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
@@ -429,78 +379,12 @@ MODEL_METADATA = {
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
||||
),
|
||||
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
|
||||
"open_router",
|
||||
262144,
|
||||
None,
|
||||
"Mistral Large 3 2512",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
None,
|
||||
"Mistral Medium 3.1",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Mistral Small 3.2 24B",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.CODESTRAL: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
None,
|
||||
"Codestral 2508",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
|
||||
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Translate 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
32768,
|
||||
"Command A Reasoning 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
3,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Vision 07.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
@@ -513,15 +397,6 @@ MODEL_METADATA = {
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8000,
|
||||
"Sonar Reasoning Pro",
|
||||
"OpenRouter",
|
||||
"Perplexity",
|
||||
2,
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
@@ -567,9 +442,6 @@ MODEL_METADATA = {
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
|
||||
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
||||
),
|
||||
@@ -579,15 +451,6 @@ MODEL_METADATA = {
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.GROK_3: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Grok 3",
|
||||
"OpenRouter",
|
||||
"xAI",
|
||||
2,
|
||||
),
|
||||
LlmModel.GROK_4: ModelMetadata(
|
||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
||||
),
|
||||
@@ -804,53 +667,36 @@ async def llm_call(
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
|
||||
|
||||
if provider == "openai":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
tools_param = convert_tools_to_responses_format(tools) if tools else openai.omit
|
||||
parallel_tool_calls = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
text_config = openai.omit
|
||||
if force_json_output:
|
||||
text_config = {"format": {"type": "json_object"}} # type: ignore
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.responses.create(
|
||||
response = await oai_client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
input=prompt, # type: ignore[arg-type]
|
||||
tools=tools_param, # type: ignore[arg-type]
|
||||
max_output_tokens=max_tokens,
|
||||
parallel_tool_calls=get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
),
|
||||
text=text_config, # type: ignore[arg-type]
|
||||
store=False,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
)
|
||||
|
||||
raw_tool_calls = extract_responses_tool_calls(response)
|
||||
tool_calls = (
|
||||
[
|
||||
ToolContentBlock(
|
||||
id=tc["id"],
|
||||
type=tc["type"],
|
||||
function=ToolCall(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
)
|
||||
for tc in raw_tool_calls
|
||||
]
|
||||
if raw_tool_calls
|
||||
else None
|
||||
)
|
||||
reasoning = extract_responses_reasoning(response)
|
||||
content = extract_responses_content(response)
|
||||
prompt_tokens, completion_tokens = extract_responses_usage(response)
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response,
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
response=content,
|
||||
response=response.choices[0].message.content or "",
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
@@ -1296,10 +1142,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = await fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = await fmt.format_string(
|
||||
input_data.sys_prompt, values
|
||||
)
|
||||
input_data.prompt = fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = fmt.format_string(input_data.sys_prompt, values)
|
||||
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from pydantic import SecretStr, field_validator
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
@@ -13,7 +13,6 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -36,20 +35,6 @@ class PerplexityModel(str, Enum):
|
||||
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
|
||||
|
||||
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
|
||||
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
|
||||
if isinstance(value, PerplexityModel):
|
||||
return value
|
||||
try:
|
||||
return PerplexityModel(value)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid PerplexityModel '{value}', "
|
||||
f"falling back to {PerplexityModel.SONAR.value}"
|
||||
)
|
||||
return PerplexityModel.SONAR
|
||||
|
||||
|
||||
PerplexityCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
|
||||
]
|
||||
@@ -88,25 +73,6 @@ class PerplexityBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
credentials: PerplexityCredentials = PerplexityCredentialsField()
|
||||
|
||||
@field_validator("model", mode="before")
|
||||
@classmethod
|
||||
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
|
||||
"""Fall back to SONAR if the model value is not a valid
|
||||
PerplexityModel (e.g. an OpenAI model ID set by the agent
|
||||
generator)."""
|
||||
return _sanitize_perplexity_model(v)
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
"""Sanitize the model field before JSON schema validation so that
|
||||
invalid values are replaced with the default instead of raising a
|
||||
BlockInputError."""
|
||||
model_value = data.get("model")
|
||||
if model_value is not None:
|
||||
data["model"] = _sanitize_perplexity_model(model_value).value
|
||||
return super().validate_data(data)
|
||||
|
||||
system_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="",
|
||||
|
||||
@@ -2232,7 +2232,6 @@ class DeleteRedditPostBlock(Block):
|
||||
("post_id", "abc123"),
|
||||
],
|
||||
test_mock={"delete_post": lambda creds, post_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -2291,7 +2290,6 @@ class DeleteRedditCommentBlock(Block):
|
||||
("comment_id", "xyz789"),
|
||||
],
|
||||
test_mock={"delete_comment": lambda creds, comment_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -72,7 +72,6 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
|
||||
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
|
||||
"_convert_to_color": lambda *args, **kwargs: "black",
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
|
||||
@@ -61,27 +61,20 @@ class ExecutionParams(BaseModel):
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool request.
|
||||
Supports OpenAI Chat Completions, Responses API, and Anthropic formats.
|
||||
Supports both OpenAI and Anthropics formats.
|
||||
"""
|
||||
tool_call_ids = []
|
||||
|
||||
# OpenAI Responses API: function_call items have type="function_call"
|
||||
if entry.get("type") == "function_call":
|
||||
if call_id := entry.get("call_id"):
|
||||
tool_call_ids.append(call_id)
|
||||
return tool_call_ids
|
||||
|
||||
if entry.get("role") != "assistant":
|
||||
return tool_call_ids
|
||||
|
||||
# OpenAI Chat Completions: check for tool_calls in the entry.
|
||||
# OpenAI: check for tool_calls in the entry.
|
||||
calls = entry.get("tool_calls")
|
||||
if isinstance(calls, list):
|
||||
for call in calls:
|
||||
if tool_id := call.get("id"):
|
||||
tool_call_ids.append(tool_id)
|
||||
|
||||
# Anthropic: check content items for tool_use type.
|
||||
# Anthropics: check content items for tool_use type.
|
||||
content = entry.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
@@ -96,22 +89,16 @@ def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool response.
|
||||
Supports OpenAI Chat Completions, Responses API, and Anthropic formats.
|
||||
Supports both OpenAI and Anthropics formats.
|
||||
"""
|
||||
tool_call_ids: list[str] = []
|
||||
|
||||
# OpenAI Responses API: function_call_output items
|
||||
if entry.get("type") == "function_call_output":
|
||||
if call_id := entry.get("call_id"):
|
||||
tool_call_ids.append(str(call_id))
|
||||
return tool_call_ids
|
||||
|
||||
# OpenAI Chat Completions: a tool response message with role "tool".
|
||||
# OpenAI: a tool response message with role "tool" and key "tool_call_id".
|
||||
if entry.get("role") == "tool":
|
||||
if tool_call_id := entry.get("tool_call_id"):
|
||||
tool_call_ids.append(str(tool_call_id))
|
||||
|
||||
# Anthropic: check content items for tool_result type.
|
||||
# Anthropics: check content items for tool_result type.
|
||||
if entry.get("role") == "user":
|
||||
content = entry.get("content")
|
||||
if isinstance(content, list):
|
||||
@@ -124,16 +111,14 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
|
||||
return tool_call_ids
|
||||
|
||||
|
||||
def _create_tool_response(
|
||||
call_id: str, output: Any, *, responses_api: bool = False
|
||||
) -> dict[str, Any]:
|
||||
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Create a tool response message for OpenAI, Anthropic, or OpenAI Responses API,
|
||||
based on the tool_id format and the responses_api flag.
|
||||
Create a tool response message for either OpenAI or Anthropics,
|
||||
based on the tool_id format.
|
||||
"""
|
||||
content = output if isinstance(output, str) else json.dumps(output)
|
||||
|
||||
# Anthropic format: tool IDs typically start with "toolu_"
|
||||
# Anthropics format: tool IDs typically start with "toolu_"
|
||||
if call_id.startswith("toolu_"):
|
||||
return {
|
||||
"role": "user",
|
||||
@@ -143,11 +128,8 @@ def _create_tool_response(
|
||||
],
|
||||
}
|
||||
|
||||
# OpenAI Responses API format
|
||||
if responses_api:
|
||||
return {"type": "function_call_output", "call_id": call_id, "output": content}
|
||||
|
||||
# OpenAI Chat Completions format (default fallback)
|
||||
# OpenAI format: tool IDs typically start with "call_".
|
||||
# Or default fallback (if the tool_id doesn't match any known prefix)
|
||||
return {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
|
||||
|
||||
@@ -195,19 +177,10 @@ def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str
|
||||
return tool_outputs
|
||||
|
||||
|
||||
def _convert_raw_response_to_dict(
|
||||
raw_response: Any,
|
||||
) -> dict[str, Any] | list[dict[str, Any]]:
|
||||
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Safely convert raw_response to dictionary format for conversation history.
|
||||
Handles different response types from different LLM providers.
|
||||
|
||||
For the OpenAI Responses API, the raw_response is the entire Response
|
||||
object. Its ``output`` items (messages, function_calls) are extracted
|
||||
individually so they can be used as valid input items on the next call.
|
||||
Returns a **list** of dicts in that case.
|
||||
|
||||
For Chat Completions / Anthropic / Ollama, returns a single dict.
|
||||
"""
|
||||
if isinstance(raw_response, str):
|
||||
# Ollama returns a string, convert to dict format
|
||||
@@ -215,28 +188,11 @@ def _convert_raw_response_to_dict(
|
||||
elif isinstance(raw_response, dict):
|
||||
# Already a dict (from tests or some providers)
|
||||
return raw_response
|
||||
elif _is_responses_api_object(raw_response):
|
||||
# OpenAI Responses API: extract individual output items
|
||||
items = [json.to_dict(item) for item in raw_response.output]
|
||||
return items if items else [{"role": "assistant", "content": ""}]
|
||||
else:
|
||||
# Chat Completions / Anthropic return message objects
|
||||
# OpenAI/Anthropic return objects, convert with json.to_dict
|
||||
return json.to_dict(raw_response)
|
||||
|
||||
|
||||
def _is_responses_api_object(obj: Any) -> bool:
|
||||
"""Detect an OpenAI Responses API Response object.
|
||||
|
||||
These have ``object == "response"`` and an ``output`` list, but no
|
||||
``role`` attribute (unlike ChatCompletionMessage).
|
||||
"""
|
||||
return (
|
||||
getattr(obj, "object", None) == "response"
|
||||
and hasattr(obj, "output")
|
||||
and not hasattr(obj, "role")
|
||||
)
|
||||
|
||||
|
||||
def get_pending_tool_calls(conversation_history: list[Any] | None) -> dict[str, int]:
|
||||
"""
|
||||
All the tool calls entry in the conversation history requires a response.
|
||||
@@ -798,34 +754,19 @@ class SmartDecisionMakerBlock(Block):
|
||||
self, prompt: list[dict], response, tool_outputs: list | None = None
|
||||
):
|
||||
"""Update conversation history with response and tool outputs."""
|
||||
converted = _convert_raw_response_to_dict(response.raw_response)
|
||||
# Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
|
||||
assistant_message = _convert_raw_response_to_dict(response.raw_response)
|
||||
has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
|
||||
item.get("type") == "tool_use"
|
||||
for item in assistant_message.get("content", [])
|
||||
)
|
||||
|
||||
if isinstance(converted, list):
|
||||
# Responses API: output items are already individual dicts
|
||||
has_tool_calls = any(
|
||||
item.get("type") == "function_call" for item in converted
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"[Reasoning]: {response.reasoning}",
|
||||
}
|
||||
)
|
||||
prompt.extend(converted)
|
||||
else:
|
||||
# Chat Completions / Anthropic: single assistant message dict
|
||||
has_tool_calls = isinstance(converted.get("content"), list) and any(
|
||||
item.get("type") == "tool_use" for item in converted.get("content", [])
|
||||
)
|
||||
if response.reasoning and not has_tool_calls:
|
||||
prompt.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"[Reasoning]: {response.reasoning}",
|
||||
}
|
||||
)
|
||||
prompt.append(converted)
|
||||
|
||||
prompt.append(assistant_message)
|
||||
|
||||
if tool_outputs:
|
||||
prompt.extend(tool_outputs)
|
||||
@@ -835,8 +776,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_info: ToolInfo,
|
||||
execution_params: ExecutionParams,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
*,
|
||||
responses_api: bool = False,
|
||||
) -> dict:
|
||||
"""Execute a single tool using the execution manager for proper integration."""
|
||||
# Lazy imports to avoid circular dependencies
|
||||
@@ -929,17 +868,13 @@ class SmartDecisionMakerBlock(Block):
|
||||
if node_outputs
|
||||
else "Tool executed successfully"
|
||||
)
|
||||
return _create_tool_response(
|
||||
tool_call.id, tool_response_content, responses_api=responses_api
|
||||
)
|
||||
return _create_tool_response(tool_call.id, tool_response_content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution with manager failed: {e}")
|
||||
# Return error response
|
||||
return _create_tool_response(
|
||||
tool_call.id,
|
||||
f"Tool execution failed: {str(e)}",
|
||||
responses_api=responses_api,
|
||||
tool_call.id, f"Tool execution failed: {str(e)}"
|
||||
)
|
||||
|
||||
async def _execute_tools_agent_mode(
|
||||
@@ -960,7 +895,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
"""Execute tools in agent mode with a loop until finished."""
|
||||
max_iterations = input_data.agent_mode_max_iterations
|
||||
iteration = 0
|
||||
use_responses_api = input_data.model.metadata.provider == "openai"
|
||||
|
||||
# Execution parameters for tool execution
|
||||
execution_params = ExecutionParams(
|
||||
@@ -1017,19 +951,14 @@ class SmartDecisionMakerBlock(Block):
|
||||
for tool_info in processed_tools:
|
||||
try:
|
||||
tool_response = await self._execute_single_tool_with_manager(
|
||||
tool_info,
|
||||
execution_params,
|
||||
execution_processor,
|
||||
responses_api=use_responses_api,
|
||||
tool_info, execution_params, execution_processor
|
||||
)
|
||||
tool_outputs.append(tool_response)
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution failed: {e}")
|
||||
# Create error response for the tool
|
||||
error_response = _create_tool_response(
|
||||
tool_info.tool_call.id,
|
||||
f"Error: {str(e)}",
|
||||
responses_api=use_responses_api,
|
||||
tool_info.tool_call.id, f"Error: {str(e)}"
|
||||
)
|
||||
tool_outputs.append(error_response)
|
||||
|
||||
@@ -1091,17 +1020,11 @@ class SmartDecisionMakerBlock(Block):
|
||||
if pending_tool_calls and input_data.last_tool_output is None:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
use_responses_api = input_data.model.metadata.provider == "openai"
|
||||
|
||||
tool_output = []
|
||||
if pending_tool_calls and input_data.last_tool_output is not None:
|
||||
first_call_id = next(iter(pending_tool_calls.keys()))
|
||||
tool_output.append(
|
||||
_create_tool_response(
|
||||
first_call_id,
|
||||
input_data.last_tool_output,
|
||||
responses_api=use_responses_api,
|
||||
)
|
||||
_create_tool_response(first_call_id, input_data.last_tool_output)
|
||||
)
|
||||
|
||||
prompt.extend(tool_output)
|
||||
@@ -1127,15 +1050,11 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = await llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = await llm.fmt.format_string(
|
||||
input_data.sys_prompt, values
|
||||
)
|
||||
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
|
||||
|
||||
if input_data.sys_prompt and not any(
|
||||
p.get("role") == "system"
|
||||
and isinstance(p.get("content"), str)
|
||||
and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
for p in prompt
|
||||
):
|
||||
prompt.append(
|
||||
@@ -1146,9 +1065,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
|
||||
if input_data.prompt and not any(
|
||||
p.get("role") == "user"
|
||||
and isinstance(p.get("content"), str)
|
||||
and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
|
||||
for p in prompt
|
||||
):
|
||||
prompt.append(
|
||||
@@ -1256,26 +1173,11 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
yield emit_key, arg_value
|
||||
|
||||
converted = _convert_raw_response_to_dict(response.raw_response)
|
||||
|
||||
# Check for tool calls to avoid inserting reasoning between tool pairs
|
||||
if isinstance(converted, list):
|
||||
has_tool_calls = any(
|
||||
item.get("type") == "function_call" for item in converted
|
||||
)
|
||||
else:
|
||||
has_tool_calls = isinstance(converted.get("content"), list) and any(
|
||||
item.get("type") == "tool_use" for item in converted.get("content", [])
|
||||
)
|
||||
|
||||
if response.reasoning and not has_tool_calls:
|
||||
if response.reasoning:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
|
||||
if isinstance(converted, list):
|
||||
prompt.extend(converted)
|
||||
else:
|
||||
prompt.append(converted)
|
||||
prompt.append(_convert_raw_response_to_dict(response.raw_response))
|
||||
|
||||
yield "conversations", prompt
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
"""Tests for AutoPilotBlock: recursion guard, streaming, validation, and error paths."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AUTOPILOT_BLOCK_ID,
|
||||
AutoPilotBlock,
|
||||
_autopilot_recursion_depth,
|
||||
_autopilot_recursion_limit,
|
||||
_check_recursion,
|
||||
_reset_recursion,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
|
||||
def _make_context(user_id: str = "test-user-123") -> ExecutionContext:
|
||||
"""Helper to build an ExecutionContext for tests."""
|
||||
return ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="graph-1",
|
||||
graph_exec_id="gexec-1",
|
||||
graph_version=1,
|
||||
node_id="node-1",
|
||||
node_exec_id="nexec-1",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recursion guard unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecursion:
|
||||
"""Unit tests for _check_recursion / _reset_recursion."""
|
||||
|
||||
def test_first_call_increments_depth(self):
|
||||
tokens = _check_recursion(3)
|
||||
try:
|
||||
assert _autopilot_recursion_depth.get() == 1
|
||||
assert _autopilot_recursion_limit.get() == 3
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
|
||||
def test_reset_restores_previous_values(self):
|
||||
assert _autopilot_recursion_depth.get() == 0
|
||||
assert _autopilot_recursion_limit.get() is None
|
||||
tokens = _check_recursion(5)
|
||||
_reset_recursion(tokens)
|
||||
assert _autopilot_recursion_depth.get() == 0
|
||||
assert _autopilot_recursion_limit.get() is None
|
||||
|
||||
def test_exceeding_limit_raises(self):
|
||||
t1 = _check_recursion(2)
|
||||
try:
|
||||
t2 = _check_recursion(2)
|
||||
try:
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
_check_recursion(2)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
def test_nested_calls_respect_inherited_limit(self):
|
||||
"""Inner call with higher max_depth still respects outer limit."""
|
||||
t1 = _check_recursion(2) # sets limit=2
|
||||
try:
|
||||
t2 = _check_recursion(10) # inner wants 10, but inherited is 2
|
||||
try:
|
||||
# depth is now 2, limit is min(10, 2) = 2 → should raise
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
_check_recursion(10)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
def test_limit_of_one_blocks_immediately_on_second_call(self):
|
||||
t1 = _check_recursion(1)
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
_check_recursion(1)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run() validation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunValidation:
|
||||
"""Tests for input validation in AutoPilotBlock.run()."""
|
||||
|
||||
@pytest.fixture
|
||||
def block(self):
|
||||
return AutoPilotBlock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt_yields_error(self, block):
|
||||
block.Input # ensure schema is accessible
|
||||
input_data = block.Input(prompt=" ", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
assert outputs.get("error") == "Prompt cannot be empty."
|
||||
assert "response" not in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_user_id_yields_error(self, block):
|
||||
input_data = block.Input(prompt="hello", max_recursion_depth=3)
|
||||
ctx = _make_context(user_id="")
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
assert "authenticated user" in outputs.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_run_yields_all_outputs(self, block):
|
||||
"""With execute_copilot mocked, run() should yield all 5 success outputs."""
|
||||
mock_result = (
|
||||
"Hello world",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"sess-abc",
|
||||
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock(return_value="sess-abc")
|
||||
|
||||
input_data = block.Input(prompt="hi", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["response"] == "Hello world"
|
||||
assert outputs["tool_calls"] == []
|
||||
assert outputs["session_id"] == "sess-abc"
|
||||
assert outputs["token_usage"]["total_tokens"] == 15
|
||||
assert "error" not in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_yields_error(self, block):
|
||||
"""On unexpected failure, run() should yield an error output."""
|
||||
block.execute_copilot = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
block.create_session = AsyncMock(return_value="sess-fail")
|
||||
|
||||
input_data = block.Input(prompt="do something", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["session_id"] == "sess-fail"
|
||||
assert "boom" in outputs.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_error_yields_error_and_reraises(self, block):
|
||||
"""CancelledError should yield error, then re-raise."""
|
||||
block.execute_copilot = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
block.create_session = AsyncMock(return_value="sess-cancel")
|
||||
|
||||
input_data = block.Input(prompt="do something", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["session_id"] == "sess-cancel"
|
||||
assert "cancelled" in outputs.get("error", "").lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_existing_session_id_skips_create(self, block):
|
||||
"""When session_id is provided, create_session should not be called."""
|
||||
mock_result = (
|
||||
"ok",
|
||||
[],
|
||||
"[]",
|
||||
"existing-sid",
|
||||
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock()
|
||||
|
||||
input_data = block.Input(
|
||||
prompt="test", session_id="existing-sid", max_recursion_depth=3
|
||||
)
|
||||
ctx = _make_context()
|
||||
async for _ in block.run(input_data, execution_context=ctx):
|
||||
pass
|
||||
|
||||
block.create_session.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block registration / ID tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlockRegistration:
|
||||
def test_block_id_matches_constant(self):
|
||||
block = AutoPilotBlock()
|
||||
assert block.id == AUTOPILOT_BLOCK_ID
|
||||
|
||||
def test_max_recursion_depth_has_upper_bound(self):
|
||||
"""Schema should enforce le=10."""
|
||||
schema = AutoPilotBlock.Input.model_json_schema()
|
||||
max_rec = schema["properties"]["max_recursion_depth"]
|
||||
assert (
|
||||
max_rec.get("maximum") == 10 or max_rec.get("exclusiveMaximum", 999) <= 11
|
||||
)
|
||||
|
||||
def test_output_schema_has_no_duplicate_error_field(self):
|
||||
"""Output should inherit error from BlockSchemaOutput, not redefine it."""
|
||||
# The field should exist (inherited) but there should be no explicit
|
||||
# redefinition. We verify by checking the class __annotations__ directly.
|
||||
assert "error" not in AutoPilotBlock.Output.__annotations__
|
||||
@@ -13,17 +13,18 @@ class TestLLMStatsTracking:
|
||||
"""Test that llm_call returns proper token counts in LLMResponse."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
# Mock the OpenAI Responses API response
|
||||
# Mock the OpenAI client
|
||||
mock_response = MagicMock()
|
||||
mock_response.output_text = "Test response"
|
||||
mock_response.output = []
|
||||
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
|
||||
mock_response.choices = [
|
||||
MagicMock(message=MagicMock(content="Test response", tool_calls=None))
|
||||
]
|
||||
mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20)
|
||||
|
||||
# Test with mocked OpenAI response
|
||||
with patch("openai.AsyncOpenAI") as mock_openai:
|
||||
mock_client = AsyncMock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_client.responses.create = AsyncMock(return_value=mock_response)
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
response = await llm.llm_call(
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
@@ -270,17 +271,30 @@ class TestLLMStatsTracking:
|
||||
mock_response = MagicMock()
|
||||
# Return different responses for chunk summary vs final summary
|
||||
if call_count == 1:
|
||||
mock_response.output_text = '<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>'
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
]
|
||||
else:
|
||||
mock_response.output_text = '<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>'
|
||||
mock_response.output = []
|
||||
mock_response.usage = MagicMock(input_tokens=50, output_tokens=30)
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
]
|
||||
mock_response.usage = MagicMock(prompt_tokens=50, completion_tokens=30)
|
||||
return mock_response
|
||||
|
||||
with patch("openai.AsyncOpenAI") as mock_openai:
|
||||
mock_client = AsyncMock()
|
||||
mock_openai.return_value = mock_client
|
||||
mock_client.responses.create = mock_create
|
||||
mock_client.chat.completions.create = mock_create
|
||||
|
||||
# Test with very short text (should only need 1 chunk + 1 final summary)
|
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
"""Unit tests for PerplexityBlock model fallback behavior."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.perplexity import (
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
PerplexityBlock,
|
||||
PerplexityModel,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**overrides) -> dict:
|
||||
defaults = {
|
||||
"prompt": "test query",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return defaults
|
||||
|
||||
|
||||
class TestPerplexityModelFallback:
|
||||
"""Tests for fallback_invalid_model field_validator."""
|
||||
|
||||
def test_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_another_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_valid_model_string_is_kept(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
|
||||
assert inp.model == PerplexityModel.SONAR_PRO
|
||||
|
||||
def test_valid_enum_value_is_kept(self):
|
||||
inp = PerplexityBlock.Input(
|
||||
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
|
||||
)
|
||||
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
|
||||
|
||||
def test_default_model_when_omitted(self):
|
||||
inp = PerplexityBlock.Input(**_make_input())
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_value",
|
||||
[
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
)
|
||||
def test_all_valid_models_accepted(self, model_value: str):
|
||||
inp = PerplexityBlock.Input(**_make_input(model=model_value))
|
||||
assert inp.model.value == model_value
|
||||
|
||||
|
||||
class TestPerplexityValidateData:
|
||||
"""Tests for validate_data which runs during block execution (before
|
||||
Pydantic instantiation). Invalid models must be sanitized here so
|
||||
JSON schema validation does not reject them."""
|
||||
|
||||
def test_invalid_model_sanitized_before_schema_validation(self):
|
||||
data = _make_input(model="gpt-5.2-2025-12-11")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == PerplexityModel.SONAR.value
|
||||
|
||||
def test_valid_model_unchanged_by_validate_data(self):
|
||||
data = _make_input(model="perplexity/sonar-pro")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == "perplexity/sonar-pro"
|
||||
|
||||
def test_missing_model_uses_default(self):
|
||||
data = _make_input() # no model key
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
inp = PerplexityBlock.Input(**data)
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
File diff suppressed because it is too large
Load Diff
@@ -290,9 +290,7 @@ class FillTextTemplateBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
formatter = text.TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", await formatter.format_string(
|
||||
input_data.format, input_data.values
|
||||
)
|
||||
yield "output", formatter.format_string(input_data.format, input_data.values)
|
||||
|
||||
|
||||
class CombineTextsBlock(Block):
|
||||
|
||||
@@ -36,15 +36,13 @@ from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
_get_openai_client,
|
||||
client,
|
||||
config,
|
||||
)
|
||||
from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -91,7 +89,7 @@ async def _compress_session_messages(
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=_get_openai_client(),
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||
@@ -223,10 +221,6 @@ async def stream_chat_completion_baseline(
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
# Token usage accumulators — populated from streaming chunks
|
||||
turn_prompt_tokens = 0
|
||||
turn_completion_tokens = 0
|
||||
_stream_error = False # Track whether an error occurred during streaming
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
@@ -238,31 +232,16 @@ async def stream_chat_completion_baseline(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
# Capture token usage from the streaming chunk.
|
||||
# OpenRouter normalises all providers into OpenAI format
|
||||
# where prompt_tokens already includes cached tokens
|
||||
# (unlike Anthropic's native API). Use += to sum all
|
||||
# tool-call rounds since each API call is independent.
|
||||
# NOTE: stream_options={"include_usage": True} is not
|
||||
# universally supported — some providers (Mistral, Llama
|
||||
# via OpenRouter) always return chunk.usage=None. When
|
||||
# that happens, tokens stay 0 and the tiktoken fallback
|
||||
# below activates. Fail-open: one round is estimated.
|
||||
if chunk.usage:
|
||||
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
@@ -415,7 +394,6 @@ async def stream_chat_completion_baseline(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_stream_error = True
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Close any open text/step before emitting error
|
||||
@@ -433,49 +411,6 @@ async def stream_chat_completion_baseline(
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Fallback: estimate tokens via tiktoken when the provider does
|
||||
# not honour stream_options={"include_usage": True}.
|
||||
# Count the full message list (system + history + turn) since
|
||||
# each API call sends the complete context window.
|
||||
# NOTE: This estimates one round's prompt tokens. Multi-round tool-calling
|
||||
# turns consume prompt tokens on each API call, so the total is underestimated.
|
||||
# Skip fallback when an error occurred and no output was produced —
|
||||
# charging rate-limit tokens for completely failed requests is unfair.
|
||||
if (
|
||||
turn_prompt_tokens == 0
|
||||
and turn_completion_tokens == 0
|
||||
and not (_stream_error and not assistant_text)
|
||||
):
|
||||
from backend.util.prompt import (
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
|
||||
turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 1
|
||||
)
|
||||
turn_completion_tokens = estimate_token_count_str(
|
||||
assistant_text, model=config.model
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
"prompt=%d, completion=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
)
|
||||
|
||||
# Persist token usage to session and record for rate limiting.
|
||||
# NOTE: OpenRouter folds cached tokens into prompt_tokens, so we
|
||||
# cannot break out cache_read/cache_creation weights. Users on the
|
||||
# baseline path may be slightly over-counted vs the SDK path.
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
@@ -486,16 +421,4 @@ async def stream_chat_completion_baseline(
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
# Yield usage and finish AFTER try/finally (not inside finally).
|
||||
# PEP 525 prohibits yielding from finally in async generators during
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
# On GeneratorExit the client is already gone, so unreachable yields
|
||||
# are harmless; on normal completion they reach the SSE stream.
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
yield StreamUsage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=turn_prompt_tokens + turn_completion_tokens,
|
||||
)
|
||||
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -70,27 +70,6 @@ class ChatConfig(BaseSettings):
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Rate limiting — token-based limits per day and per week.
|
||||
# Per-turn token cost varies with context size: ~10-15K for early turns,
|
||||
# ~30-50K mid-session, up to ~100K pre-compaction. Average across a
|
||||
# session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily
|
||||
# allows ~70-100 turns/day.
|
||||
# Checked at the HTTP layer (routes.py) before each turn.
|
||||
#
|
||||
# TODO: These are deploy-time constants applied identically to every user.
|
||||
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
|
||||
# must move to the database (e.g., a UserPlan table) and get_usage_status /
|
||||
# check_rate_limit would look up each user's specific limits instead of
|
||||
# reading config.daily_token_limit / config.weekly_token_limit.
|
||||
daily_token_limit: int = Field(
|
||||
default=2_500_000,
|
||||
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
|
||||
)
|
||||
weekly_token_limit: int = Field(
|
||||
default=12_500_000,
|
||||
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
@@ -115,22 +94,10 @@ class ChatConfig(BaseSettings):
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
use_openrouter: bool = Field(
|
||||
default=True,
|
||||
description="Enable routing API calls through the OpenRouter proxy. "
|
||||
"The actual decision also requires ``api_key`` and ``base_url`` — "
|
||||
"use the ``openrouter_active`` property for the final answer.",
|
||||
)
|
||||
use_claude_code_subscription: bool = Field(
|
||||
default=False,
|
||||
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
|
||||
)
|
||||
test_mode: bool = Field(
|
||||
default=False,
|
||||
description="Use dummy service instead of real LLM calls. "
|
||||
"Send __test_transient_error__, __test_fatal_error__, or "
|
||||
"__test_slow_response__ to trigger specific scenarios.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
@@ -148,7 +115,7 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=420, # 7 min safety net — allows headroom for compaction retries
|
||||
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
@@ -158,21 +125,6 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
|
||||
)
|
||||
|
||||
@property
|
||||
def openrouter_active(self) -> bool:
|
||||
"""True when OpenRouter is enabled AND credentials are usable.
|
||||
|
||||
Single source of truth for "will the SDK route through OpenRouter?".
|
||||
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
|
||||
present — mirrors the fallback logic in ``_build_sdk_env``.
|
||||
"""
|
||||
if not self.use_openrouter:
|
||||
return False
|
||||
base = (self.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
return bool(self.api_key and base and base.startswith("http"))
|
||||
|
||||
@property
|
||||
def e2b_active(self) -> bool:
|
||||
"""True when E2B is enabled and the API key is present.
|
||||
@@ -195,6 +147,15 @@ class ChatConfig(BaseSettings):
|
||||
"""
|
||||
return self.e2b_api_key if self.e2b_active else None
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
def get_use_e2b_sandbox(cls, v):
|
||||
"""Get use_e2b_sandbox from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("e2b_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
@@ -237,6 +198,26 @@ class ChatConfig(BaseSettings):
|
||||
v = OPENROUTER_BASE_URL
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||
# Check environment variable - default to True if not set
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("use_claude_code_subscription", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_code_subscription(cls, v):
|
||||
"""Get use_claude_code_subscription from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return False if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
@@ -246,7 +227,6 @@ class ChatConfig(BaseSettings):
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
env_prefix = "CHAT_"
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore" # Ignore extra environment variables
|
||||
|
||||
@@ -6,70 +6,19 @@ from .config import ChatConfig
|
||||
|
||||
# Env vars that the ChatConfig validators read — must be cleared so they don't
|
||||
# override the explicit constructor values we pass in each test.
|
||||
_ENV_VARS_TO_CLEAR = (
|
||||
_E2B_ENV_VARS = (
|
||||
"CHAT_USE_E2B_SANDBOX",
|
||||
"CHAT_E2B_API_KEY",
|
||||
"E2B_API_KEY",
|
||||
"CHAT_USE_OPENROUTER",
|
||||
"CHAT_API_KEY",
|
||||
"OPEN_ROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"CHAT_BASE_URL",
|
||||
"OPENROUTER_BASE_URL",
|
||||
"OPENAI_BASE_URL",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _ENV_VARS_TO_CLEAR:
|
||||
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _E2B_ENV_VARS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestOpenrouterActive:
|
||||
"""Tests for the openrouter_active property."""
|
||||
|
||||
def test_enabled_with_credentials_returns_true(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is True
|
||||
|
||||
def test_enabled_but_missing_api_key_returns_false(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key=None,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
def test_disabled_returns_false_despite_credentials(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
def test_strips_v1_suffix_and_still_valid(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is True
|
||||
|
||||
def test_invalid_base_url_returns_false(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="not-a-url",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
|
||||
class TestE2BActive:
|
||||
"""Tests for the e2b_active property — single source of truth for E2B usage."""
|
||||
|
||||
|
||||
@@ -4,9 +4,6 @@
|
||||
# The hex suffix makes accidental LLM generation of these strings virtually
|
||||
# impossible, avoiding false-positive marker detection in normal conversation.
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX = (
|
||||
"[__COPILOT_RETRYABLE_ERROR_a9c2__]" # ErrorCard + retry
|
||||
)
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||
@@ -38,24 +35,3 @@ def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||
Format: "{node_id}:{random_hex}" → returns "{node_id}".
|
||||
"""
|
||||
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transient Anthropic API error detection
|
||||
# ---------------------------------------------------------------------------
|
||||
# Patterns in error text that indicate a transient Anthropic API error
|
||||
# (ECONNRESET / dropped TCP connection) which is retryable.
|
||||
_TRANSIENT_ERROR_PATTERNS = (
|
||||
"socket connection was closed unexpectedly",
|
||||
"ECONNRESET",
|
||||
"connection was forcibly closed",
|
||||
"network socket disconnected",
|
||||
)
|
||||
|
||||
FRIENDLY_TRANSIENT_MSG = "Anthropic connection interrupted — please retry"
|
||||
|
||||
|
||||
def is_transient_api_error(error_text: str) -> bool:
|
||||
"""Return True if *error_text* matches a known transient Anthropic API error."""
|
||||
lower = error_text.lower()
|
||||
return any(pat.lower() in lower for pat in _TRANSIENT_ERROR_PATTERNS)
|
||||
|
||||
@@ -11,23 +11,12 @@ from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
# Allowed base directory for the Read tool. Public so service.py can use it
|
||||
# for sweep operations without depending on a private implementation detail.
|
||||
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
|
||||
# _projects_base() function.
|
||||
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
|
||||
|
||||
# Compiled UUID pattern for validating conversation directory names.
|
||||
# Kept as a module-level constant so the security-relevant pattern is easy
|
||||
# to audit in one place and avoids recompilation on every call.
|
||||
_UUID_RE = re.compile(r"^[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}$", re.IGNORECASE)
|
||||
# Allowed base directory for the Read tool.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Encoded project-directory name for the current session (e.g.
|
||||
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||
@@ -44,20 +33,11 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
|
||||
def encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The Claude CLI encodes the absolute cwd as a directory name by replacing
|
||||
every non-alphanumeric character with ``-``. For example
|
||||
``/tmp/copilot-abc`` becomes ``-tmp-copilot-abc``.
|
||||
"""
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does."""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
# Keep the private alias for internal callers (backwards compat).
|
||||
_encode_cwd_for_cli = encode_cwd_for_cli
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
@@ -102,25 +82,12 @@ def resolve_sandbox_path(path: str) -> str:
|
||||
return normalized
|
||||
|
||||
|
||||
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped :class:`WorkspaceManager`.
|
||||
|
||||
Placed here (rather than in ``tools/workspace_files``) so that modules
|
||||
like ``sdk/file_ref`` can import it without triggering the heavy
|
||||
``tools/__init__`` import chain.
|
||||
"""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Return True if *path* is within an allowed host-filesystem location.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
|
||||
The SDK nests tool-results under a conversation UUID directory;
|
||||
the UUID segment is validated with ``_UUID_RE``.
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
@@ -139,22 +106,10 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
project_dir = os.path.realpath(os.path.join(SDK_PROJECTS_DIR, encoded))
|
||||
# Defence-in-depth: ensure project_dir didn't escape the base.
|
||||
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
|
||||
return False
|
||||
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
|
||||
# The SDK always creates a conversation UUID directory between
|
||||
# the project dir and tool-results/.
|
||||
if resolved.startswith(project_dir + os.sep):
|
||||
relative = resolved[len(project_dir) + 1 :]
|
||||
parts = relative.split(os.sep)
|
||||
# Require exactly: [<uuid>, "tool-results", <file>, ...]
|
||||
if (
|
||||
len(parts) >= 3
|
||||
and _UUID_RE.match(parts[0])
|
||||
and parts[1] == "tool-results"
|
||||
):
|
||||
return True
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
if resolved == tool_results_dir or resolved.startswith(
|
||||
tool_results_dir + os.sep
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -9,7 +9,7 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import (
|
||||
SDK_PROJECTS_DIR,
|
||||
_SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
@@ -104,13 +104,11 @@ def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_with_uuid():
|
||||
"""Files under <encoded-cwd>/<uuid>/tool-results/ are allowed."""
|
||||
def test_is_allowed_local_path_tool_results_dir():
|
||||
"""Files under the tool-results directory for the current project are allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results", "output.txt"
|
||||
)
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
path = os.path.join(tool_results_dir, "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -119,22 +117,10 @@ def test_is_allowed_local_path_tool_results_with_uuid():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
|
||||
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
path = os.path.join(SDK_PROJECTS_DIR, encoded, "tool-results", "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
sibling_path = os.path.join(SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -143,21 +129,6 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
|
||||
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
uuid_str = "12345678-1234-5678-9abc-def012345678"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, uuid_str, "not-tool-results", "output.txt"
|
||||
)
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -16,7 +16,6 @@ from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
@@ -247,25 +246,17 @@ class CoPilotProcessor:
|
||||
# Choose service based on LaunchDarkly flag.
|
||||
# Claude Code subscription forces SDK mode (CLI subprocess auth).
|
||||
config = ChatConfig()
|
||||
|
||||
if config.test_mode:
|
||||
stream_fn = stream_chat_completion_dummy
|
||||
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
|
||||
else:
|
||||
use_sdk = (
|
||||
config.use_claude_code_subscription
|
||||
or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
"""Integration credential lookup with per-process TTL cache.
|
||||
|
||||
Provides token retrieval for connected integrations so that copilot tools
|
||||
(e.g. bash_exec) can inject auth tokens into the execution environment without
|
||||
hitting the database on every command.
|
||||
|
||||
Cache semantics (handled automatically by TTLCache):
|
||||
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
|
||||
for users who have credentials and are running many bash commands.
|
||||
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
|
||||
on every E2B command for users who haven't connected an account yet, while
|
||||
still picking up a newly-connected account within one minute.
|
||||
|
||||
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
|
||||
least-recently-used entry when the limit is reached.
|
||||
|
||||
Multi-worker note: both caches are in-process only. Each worker/replica
|
||||
maintains its own independent cache, so a credential fetch may be duplicated
|
||||
across processes. This is acceptable for the current goal (reduce DB hits per
|
||||
session per-process), but if cache efficiency across replicas becomes important
|
||||
a shared cache (e.g. Redis) should be used instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from backend.copilot.providers import SUPPORTED_PROVIDERS
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
register_creds_changed_hook,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Derived from the single SUPPORTED_PROVIDERS registry for backward compat.
|
||||
PROVIDER_ENV_VARS: dict[str, list[str]] = {
|
||||
slug: entry["env_vars"] for slug, entry in SUPPORTED_PROVIDERS.items()
|
||||
}
|
||||
|
||||
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
|
||||
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
|
||||
_CACHE_MAX_SIZE = 10_000
|
||||
|
||||
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
|
||||
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
|
||||
# because all callers (get_provider_token, invalidate_user_provider_cache) run
|
||||
# exclusively on the asyncio event loop. There are no await points between a
|
||||
# cache read and its corresponding write within any function, so no concurrent
|
||||
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
|
||||
# this path, a threading.RLock should be wrapped around these caches.
|
||||
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
|
||||
)
|
||||
# Separate cache for "no credentials" results with a shorter TTL.
|
||||
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
# entries. This avoids a lazy import inside creds_manager and eliminates the
|
||||
# circular-import risk.
|
||||
try:
|
||||
register_creds_changed_hook(invalidate_user_provider_cache)
|
||||
except RuntimeError:
|
||||
# Hook already registered (e.g. module re-import in tests).
|
||||
pass
|
||||
|
||||
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
|
||||
# on every cache-miss call to get_provider_token().
|
||||
_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
"""Return the user's access token for *provider*, or ``None`` if not connected.
|
||||
|
||||
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
|
||||
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
|
||||
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
|
||||
command for users who haven't connected yet, while still picking up a
|
||||
newly-connected account within one minute.
|
||||
"""
|
||||
cache_key = (user_id, provider)
|
||||
|
||||
if cache_key in _null_cache:
|
||||
return None
|
||||
if cached := _token_cache.get(cache_key):
|
||||
return cached
|
||||
|
||||
manager = _manager
|
||||
try:
|
||||
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch %s credentials for user %s",
|
||||
provider,
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
|
||||
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
|
||||
# full git access, while a public-data-only token lacks push/pull permission.
|
||||
# lock=False — background injection; not worth a distributed lock acquisition.
|
||||
oauth2_creds = sorted(
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
fresh = await manager.refresh_if_needed(
|
||||
user_id, cast(OAuth2Credentials, creds), lock=False
|
||||
)
|
||||
token = fresh.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh %s OAuth token for user %s; "
|
||||
"discarding stale token to force re-auth",
|
||||
provider,
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
# Do NOT fall back to the stale token — it is likely expired
|
||||
# or revoked. Returning None forces the caller to re-auth,
|
||||
# preventing the LLM from receiving a non-functional token.
|
||||
continue
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# Pass 2: fall back to API key (no expiry, no refresh needed).
|
||||
for creds in creds_list:
|
||||
if creds.type == "api_key":
|
||||
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
"""Return env vars for all providers the user has connected.
|
||||
|
||||
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
|
||||
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
|
||||
Only providers with a stored credential contribute entries.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
for provider, var_names in PROVIDER_ENV_VARS.items():
|
||||
token = await get_provider_token(user_id, provider)
|
||||
if token:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
@@ -1,195 +0,0 @@
|
||||
"""Tests for integration_creds — TTL cache and token lookup paths."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
get_provider_token,
|
||||
invalidate_user_provider_cache,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
|
||||
_USER = "user-integration-creds-test"
|
||||
_PROVIDER = "github"
|
||||
|
||||
|
||||
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="creds-api-key",
|
||||
provider=_PROVIDER,
|
||||
api_key=SecretStr(key),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id="creds-oauth2",
|
||||
provider=_PROVIDER,
|
||||
title="Test OAuth",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr("test-refresh"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
def test_removes_token_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_token_cache[key] = "tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _token_cache
|
||||
|
||||
def test_removes_null_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_null_cache[key] = True
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _null_cache
|
||||
|
||||
def test_noop_when_key_not_cached(self):
|
||||
# Should not raise even when there is no cache entry.
|
||||
invalidate_user_provider_cache("no-such-user", _PROVIDER)
|
||||
|
||||
def test_only_removes_targeted_key(self):
|
||||
other_key = ("other-user", _PROVIDER)
|
||||
_token_cache[other_key] = "other-tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_cached_token_without_db_hit(self):
|
||||
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "cached-tok"
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_none_for_null_cached_provider(self):
|
||||
_null_cache[(_USER, _PROVIDER)] = True
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_api_key_creds_returned_and_cached(self):
|
||||
api_creds = _make_api_key_creds("my-api-key")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "my-api-key"
|
||||
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_preferred_over_api_key(self):
|
||||
oauth_creds = _make_oauth2_creds("oauth-tok")
|
||||
api_creds = _make_api_key_creds("api-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[api_creds, oauth_creds]
|
||||
)
|
||||
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_returns_none(self):
|
||||
"""On refresh failure, return None instead of caching a stale token."""
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
# Stale tokens must NOT be returned — forces re-auth.
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
assert _null_cache.get((_USER, _PROVIDER)) is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_db_exception_returns_none_without_caching(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
side_effect=RuntimeError("db down")
|
||||
)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
# DB errors are not cached — next call will retry
|
||||
assert (_USER, _PROVIDER) not in _token_cache
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
|
||||
"""Verify the TTL constants are set correctly for each cache."""
|
||||
assert _null_cache.ttl == _NULL_CACHE_TTL
|
||||
assert _token_cache.ttl == _TOKEN_CACHE_TTL
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
_token_cache[(_USER, "github")] = "gh-tok"
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
for var in PROVIDER_ENV_VARS["github"]:
|
||||
assert result[var] == "gh-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_dict_when_no_credentials(self):
|
||||
_null_cache[(_USER, "github")] = True
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
assert result == {}
|
||||
@@ -73,9 +73,6 @@ class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
|
||||
cache_read_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
|
||||
|
||||
class ChatSessionInfo(BaseModel):
|
||||
@@ -101,10 +98,7 @@ class ChatSessionInfo(BaseModel):
|
||||
prisma_session.successfulAgentSchedules, default={}
|
||||
)
|
||||
|
||||
# Calculate usage from token counts.
|
||||
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
|
||||
# is lost after persistence — the DB only stores aggregate prompt and
|
||||
# completion totals. This is a known limitation.
|
||||
# Calculate usage from token counts
|
||||
usage = []
|
||||
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
|
||||
usage.append(
|
||||
|
||||
@@ -6,11 +6,10 @@ handling the distinction between:
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.blocks.autopilot import AUTOPILOT_BLOCK_ID
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = f"""\
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
@@ -53,82 +52,10 @@ Examples:
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
**Structured data**: When the **entire** argument value is a single file
|
||||
reference (no surrounding text), the platform automatically parses the file
|
||||
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||
unrecognised or parsing fails, the content is returned as a plain string.
|
||||
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||
|
||||
**Type coercion**: The platform also coerces expanded values to match the
|
||||
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||
|
||||
### Media file inputs (format: "file")
|
||||
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||
These fields accept:
|
||||
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
|
||||
for large files (images, videos, PDFs). The platform passes the reference
|
||||
directly to the block without reading the content into memory.
|
||||
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
|
||||
small files only.
|
||||
|
||||
When a block input has `format: "file"`, **pass the `workspace://` URI
|
||||
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
|
||||
payloads in tool arguments and preserves binary content (images, videos)
|
||||
that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{{
|
||||
"files": [{{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}}]
|
||||
}}
|
||||
```
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
|
||||
### Delegating to another autopilot (sub-autopilot pattern)
|
||||
Use the **AutoPilotBlock** (`run_block` with block_id
|
||||
`{AUTOPILOT_BLOCK_ID}`) to delegate a task to a fresh
|
||||
autopilot instance. The sub-autopilot has its own full tool set and can
|
||||
perform multi-step work autonomously.
|
||||
|
||||
- **Input**: `prompt` (required) — the task description.
|
||||
Optional: `system_context` to constrain behavior, `session_id` to
|
||||
continue a previous conversation, `max_recursion_depth` (default 3).
|
||||
- **Output**: `response` (text), `tool_calls` (list), `session_id`
|
||||
(for continuation), `conversation_history`, `token_usage`.
|
||||
|
||||
Use this when a task is complex enough to benefit from a separate
|
||||
autopilot context, e.g. "research X and write a report" while the
|
||||
parent autopilot handles orchestration.
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
- For operations that need broader access (e.g. private org repos, GitHub
|
||||
Actions), pass the required scopes: e.g.
|
||||
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -141,7 +68,6 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
extra_notes: str = "",
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
@@ -156,7 +82,6 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
extra_notes: Environment-specific notes appended after shared notes
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
@@ -190,23 +115,12 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
|
||||
### SDK tool-result files
|
||||
When tool outputs are large, the SDK truncates them and saves the full output to
|
||||
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
|
||||
always use `read_file` or `Read` (NOT `read_workspace_file`).
|
||||
`read_workspace_file` reads from cloud workspace storage, where SDK
|
||||
tool-results are NOT stored.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns).
|
||||
|
||||
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
|
||||
like gh will not work — no integration env-var notes are included.
|
||||
"""
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
@@ -224,11 +138,7 @@ def _get_local_storage_supplement(cwd: str) -> str:
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session).
|
||||
|
||||
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
|
||||
injected per command in bash_exec — include the CLI guidance notes.
|
||||
"""
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
@@ -243,7 +153,6 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
extra_notes=_E2B_TOOL_NOTES,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
"""Single source of truth for copilot-supported integration providers.
|
||||
|
||||
Both :mod:`~backend.copilot.integration_creds` (env-var injection) and
|
||||
:mod:`~backend.copilot.tools.connect_integration` (UI setup card) import from
|
||||
here, eliminating the risk of the two registries drifting out of sync.
|
||||
"""
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class ProviderEntry(TypedDict):
|
||||
"""Metadata for a supported integration provider.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable display name (e.g. "GitHub").
|
||||
env_vars: Environment variable names injected when the provider is
|
||||
connected (e.g. ``["GH_TOKEN", "GITHUB_TOKEN"]``).
|
||||
default_scopes: Default OAuth scopes requested when the agent does not
|
||||
specify any.
|
||||
"""
|
||||
|
||||
name: str
|
||||
env_vars: list[str]
|
||||
default_scopes: list[str]
|
||||
|
||||
|
||||
def _is_github_oauth_configured() -> bool:
|
||||
"""Return True if GitHub OAuth env vars are set.
|
||||
|
||||
Uses a lazy import to avoid triggering ``Secrets()`` during module import,
|
||||
which can fail in environments where secrets are not yet loaded (e.g. tests,
|
||||
CLI tooling).
|
||||
"""
|
||||
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
return GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
|
||||
# -- Registry ----------------------------------------------------------------
|
||||
# Add new providers here. Both env-var injection and the setup-card tool read
|
||||
# from this single registry.
|
||||
|
||||
SUPPORTED_PROVIDERS: dict[str, ProviderEntry] = {
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"env_vars": ["GH_TOKEN", "GITHUB_TOKEN"],
|
||||
"default_scopes": ["repo"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_provider_auth_types(provider: str) -> list[str]:
|
||||
"""Return the supported credential types for *provider* at runtime.
|
||||
|
||||
OAuth types are only offered when the corresponding OAuth client env vars
|
||||
are configured.
|
||||
"""
|
||||
if provider == "github":
|
||||
if _is_github_oauth_configured():
|
||||
return ["api_key", "oauth2"]
|
||||
return ["api_key"]
|
||||
# Default for unknown/future providers — API key only.
|
||||
return ["api_key"]
|
||||
@@ -1,266 +0,0 @@
|
||||
"""CoPilot rate limiting based on token usage.
|
||||
|
||||
Uses Redis fixed-window counters to track per-user token consumption
|
||||
with configurable daily and weekly limits. Daily windows reset at
|
||||
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
|
||||
UTC). Fails open when Redis is unavailable to avoid blocking users.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis key prefixes
|
||||
_USAGE_KEY_PREFIX = "copilot:usage"
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
"""Usage within a single time window."""
|
||||
|
||||
used: int
|
||||
limit: int = Field(
|
||||
description="Maximum tokens allowed in this window. 0 means unlimited."
|
||||
)
|
||||
resets_at: datetime
|
||||
|
||||
|
||||
class CoPilotUsageStatus(BaseModel):
|
||||
"""Current usage status for a user across all windows."""
|
||||
|
||||
daily: UsageWindow
|
||||
weekly: UsageWindow
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""Raised when a user exceeds their CoPilot usage limit."""
|
||||
|
||||
def __init__(self, window: str, resets_at: datetime):
|
||||
self.window = window
|
||||
self.resets_at = resets_at
|
||||
delta = resets_at - datetime.now(UTC)
|
||||
total_secs = delta.total_seconds()
|
||||
if total_secs <= 0:
|
||||
time_str = "now"
|
||||
else:
|
||||
hours = int(total_secs // 3600)
|
||||
minutes = int((total_secs % 3600) // 60)
|
||||
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
|
||||
super().__init__(
|
||||
f"You've reached your {window} usage limit. Resets in {time_str}."
|
||||
)
|
||||
|
||||
|
||||
async def get_usage_status(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get current usage status for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
daily_used = 0
|
||||
weekly_used = 0
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
daily_raw, weekly_raw = await asyncio.gather(
|
||||
redis.get(_daily_key(user_id, now=now)),
|
||||
redis.get(_weekly_key(user_id, now=now)),
|
||||
)
|
||||
daily_used = int(daily_raw or 0)
|
||||
weekly_used = int(weekly_raw or 0)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for usage status, returning zeros")
|
||||
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used,
|
||||
limit=daily_token_limit,
|
||||
resets_at=_daily_reset_time(now=now),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used,
|
||||
limit=weekly_token_limit,
|
||||
resets_at=_weekly_reset_time(now=now),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def check_rate_limit(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> None:
|
||||
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
|
||||
|
||||
This is a pre-turn soft check. The authoritative usage counter is updated
|
||||
by ``record_token_usage()`` after the turn completes. Under concurrency,
|
||||
two parallel turns may both pass this check against the same snapshot.
|
||||
This is acceptable because token-based limits are approximate by nature
|
||||
(the exact token count is unknown until after generation).
|
||||
|
||||
Fails open: if Redis is unavailable, allows the request.
|
||||
"""
|
||||
# Short-circuit: when both limits are 0 (unlimited) skip the Redis
|
||||
# round-trip entirely.
|
||||
if daily_token_limit <= 0 and weekly_token_limit <= 0:
|
||||
return
|
||||
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
daily_raw, weekly_raw = await asyncio.gather(
|
||||
redis.get(_daily_key(user_id, now=now)),
|
||||
redis.get(_weekly_key(user_id, now=now)),
|
||||
)
|
||||
daily_used = int(daily_raw or 0)
|
||||
weekly_used = int(weekly_raw or 0)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for rate limit check, allowing request")
|
||||
return
|
||||
|
||||
# Worst-case overshoot: N concurrent requests × ~15K tokens each.
|
||||
if daily_token_limit > 0 and daily_used >= daily_token_limit:
|
||||
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
|
||||
|
||||
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
|
||||
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
|
||||
|
||||
|
||||
async def record_token_usage(
|
||||
user_id: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> None:
|
||||
"""Record token usage for a user across all windows.
|
||||
|
||||
Uses cost-weighted counting so cached tokens don't unfairly penalise
|
||||
multi-turn conversations. Anthropic's pricing:
|
||||
- uncached input: 100%
|
||||
- cache creation: 25%
|
||||
- cache read: 10%
|
||||
- output: 100%
|
||||
|
||||
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
|
||||
from the API response). Cache counts are passed separately.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (10% cost).
|
||||
cache_creation_tokens: Tokens written to prompt cache (25% cost).
|
||||
"""
|
||||
prompt_tokens = max(0, prompt_tokens)
|
||||
completion_tokens = max(0, completion_tokens)
|
||||
cache_read_tokens = max(0, cache_read_tokens)
|
||||
cache_creation_tokens = max(0, cache_creation_tokens)
|
||||
|
||||
weighted_input = (
|
||||
prompt_tokens
|
||||
+ round(cache_creation_tokens * 0.25)
|
||||
+ round(cache_read_tokens * 0.1)
|
||||
)
|
||||
total = weighted_input + completion_tokens
|
||||
if total <= 0:
|
||||
return
|
||||
|
||||
raw_total = (
|
||||
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
|
||||
)
|
||||
logger.info(
|
||||
"Recording token usage for %s: raw=%d, weighted=%d "
|
||||
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
|
||||
user_id[:8],
|
||||
raw_total,
|
||||
total,
|
||||
prompt_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_tokens,
|
||||
completion_tokens,
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
# transaction=False: these are independent INCRBY+EXPIRE pairs on
|
||||
# separate keys — no cross-key atomicity needed. Skipping
|
||||
# MULTI/EXEC avoids the overhead. If the connection drops between
|
||||
# INCRBY and EXPIRE the key survives until the next date-based key
|
||||
# rotation (daily/weekly), so the memory-leak risk is negligible.
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
|
||||
# Daily counter (expires at next midnight UTC)
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
pipe.incrby(d_key, total)
|
||||
seconds_until_daily_reset = int(
|
||||
(_daily_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
|
||||
|
||||
# Weekly counter (expires end of week)
|
||||
w_key = _weekly_key(user_id, now=now)
|
||||
pipe.incrby(w_key, total)
|
||||
seconds_until_weekly_reset = int(
|
||||
(_weekly_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
|
||||
|
||||
await pipe.execute()
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning(
|
||||
"Redis unavailable for recording token usage (tokens=%d)",
|
||||
total,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _daily_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return f"{_USAGE_KEY_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||
|
||||
|
||||
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
year, week, _ = now.isocalendar()
|
||||
return f"{_USAGE_KEY_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
|
||||
|
||||
|
||||
def _daily_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current daily window resets (next midnight UTC)."""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||
|
||||
|
||||
def _weekly_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current weekly window resets (next Monday 00:00 UTC)."""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
days_until_monday = (7 - now.weekday()) % 7 or 7
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
|
||||
days=days_until_monday
|
||||
)
|
||||
@@ -1,334 +0,0 @@
|
||||
"""Unit tests for CoPilot rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from .rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
record_token_usage,
|
||||
)
|
||||
|
||||
_USER = "test-user-rl"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RateLimitExceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitExceeded:
|
||||
def test_message_contains_window_name(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
|
||||
assert "daily" in str(exc)
|
||||
|
||||
def test_message_contains_reset_time(self):
|
||||
exc = RateLimitExceeded(
|
||||
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||
)
|
||||
msg = str(exc)
|
||||
# Allow for slight timing drift (29m or 30m)
|
||||
assert "2h " in msg
|
||||
assert "Resets in" in msg
|
||||
|
||||
def test_message_minutes_only_when_under_one_hour(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
|
||||
msg = str(exc)
|
||||
assert "Resets in" in msg
|
||||
# Should not have "0h"
|
||||
assert "0h" not in msg
|
||||
|
||||
def test_message_says_now_when_resets_at_is_in_the_past(self):
|
||||
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
|
||||
assert "Resets in now" in str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_usage_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetUsageStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_redis_values(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert isinstance(status, CoPilotUsageStatus)
|
||||
assert status.daily.used == 500
|
||||
assert status.daily.limit == 10000
|
||||
assert status.weekly.used == 2000
|
||||
assert status.weekly.limit == 50000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_zeros_when_redis_unavailable(self):
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_daily_counter(self):
|
||||
"""Daily counter is None (new day), weekly has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 3000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_weekly_counter(self):
|
||||
"""Weekly counter is None (start of week), daily has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", None])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 500
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resets_at_daily_is_next_midnight_utc(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["0", "0"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
# Daily reset should be within 24h
|
||||
assert status.daily.resets_at > now
|
||||
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_rate_limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRateLimit:
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_under_limit(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_daily_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "daily"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_weekly_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "weekly"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_redis_unavailable(self):
|
||||
"""Fail-open: allow requests when Redis is down."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_check_when_limit_is_zero(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — limits of 0 mean unlimited
|
||||
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# record_token_usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecordTokenUsage:
|
||||
@staticmethod
|
||||
def _make_pipeline_mock() -> MagicMock:
|
||||
"""Create a pipeline mock with sync methods and async execute."""
|
||||
pipe = MagicMock()
|
||||
pipe.execute = AsyncMock(return_value=[])
|
||||
return pipe
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increments_redis_counters(self):
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
# Should call incrby twice (daily + weekly) with total=150
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 150 # daily
|
||||
assert incrby_calls[1].args[1] == 150 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_zero_tokens(self):
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
# Should not call pipeline at all
|
||||
mock_redis.pipeline.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_expire_on_both_keys(self):
|
||||
"""Pipeline should call expire for both daily and weekly keys."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
expire_calls = mock_pipe.expire.call_args_list
|
||||
assert len(expire_calls) == 2
|
||||
|
||||
# Daily key TTL should be positive (seconds until next midnight)
|
||||
daily_ttl = expire_calls[0].args[1]
|
||||
assert daily_ttl >= 1
|
||||
|
||||
# Weekly key TTL should be positive (seconds until next Monday)
|
||||
weekly_ttl = expire_calls[1].args[1]
|
||||
assert weekly_ttl >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_failure_gracefully(self):
|
||||
"""Should not raise when Redis is unavailable."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_weighted_counting(self):
|
||||
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(
|
||||
_USER,
|
||||
prompt_tokens=100, # uncached → 100
|
||||
completion_tokens=50, # output → 50
|
||||
cache_read_tokens=10000, # 10% → 1000
|
||||
cache_creation_tokens=400, # 25% → 100
|
||||
)
|
||||
|
||||
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 1250 # daily
|
||||
assert incrby_calls[1].args[1] == 1250 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_error_during_pipeline_execute(self):
|
||||
"""Should not raise when pipeline.execute() fails with RedisError."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — fail-open
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
@@ -43,7 +43,6 @@ class ResponseType(str, Enum):
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
HEARTBEAT = "heartbeat"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
@@ -187,43 +186,12 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics.
|
||||
|
||||
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
|
||||
(it uses z.strictObject() and rejects unknown event types).
|
||||
Usage data is recorded server-side (session DB + Redis counters).
|
||||
"""
|
||||
"""Token usage statistics."""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
prompt_tokens: int = Field(
|
||||
...,
|
||||
serialization_alias="promptTokens",
|
||||
description="Number of uncached prompt tokens",
|
||||
)
|
||||
completion_tokens: int = Field(
|
||||
...,
|
||||
serialization_alias="completionTokens",
|
||||
description="Number of completion tokens",
|
||||
)
|
||||
total_tokens: int = Field(
|
||||
...,
|
||||
serialization_alias="totalTokens",
|
||||
description="Total number of tokens (raw, not weighted)",
|
||||
)
|
||||
cache_read_tokens: int = Field(
|
||||
default=0,
|
||||
serialization_alias="cacheReadTokens",
|
||||
description="Prompt tokens served from cache (10% cost)",
|
||||
)
|
||||
cache_creation_tokens: int = Field(
|
||||
default=0,
|
||||
serialization_alias="cacheCreationTokens",
|
||||
description="Prompt tokens written to cache (25% cost)",
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Emit as SSE comment so the AI SDK parser ignores it."""
|
||||
return f": usage {self.model_dump_json(exclude_none=True, by_alias=True)}\n\n"
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
@@ -264,19 +232,3 @@ class StreamHeartbeat(StreamBaseResponse):
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE comment format to keep connection alive."""
|
||||
return ": heartbeat\n\n"
|
||||
|
||||
|
||||
class StreamStatus(StreamBaseResponse):
|
||||
"""Transient status notification shown to the user during long operations.
|
||||
|
||||
Used to provide feedback when the backend performs behind-the-scenes work
|
||||
(e.g., compacting conversation context on a retry) that would otherwise
|
||||
leave the user staring at an unexplained pause.
|
||||
|
||||
Sent as a proper ``data:`` event so the frontend can display it to the
|
||||
user. The AI SDK stream parser gracefully skips unknown chunk types
|
||||
(logs a console warning), so this does not break the stream.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.STATUS
|
||||
message: str = Field(..., description="Human-readable status message")
|
||||
|
||||
@@ -3,57 +3,12 @@
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
|
||||
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
|
||||
circular import cycle::
|
||||
|
||||
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
|
||||
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
|
||||
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
|
||||
|
||||
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
|
||||
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
|
||||
function scope without a larger refactor (moving tool-name registration
|
||||
to a separate lightweight module). The lazy-import pattern here is the
|
||||
least invasive way to break the cycle while keeping module-level constants
|
||||
intact.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
# Static imports for type checkers so they can resolve __all__ entries
|
||||
# without executing the lazy-import machinery at runtime.
|
||||
if TYPE_CHECKING:
|
||||
from .collect import CopilotResult as CopilotResult
|
||||
from .collect import collect_copilot_response as collect_copilot_response
|
||||
from .service import stream_chat_completion_sdk as stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server as create_copilot_mcp_server
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"CopilotResult",
|
||||
"collect_copilot_response",
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"CopilotResult": (".collect", "CopilotResult"),
|
||||
"collect_copilot_response": (".collect", "collect_copilot_response"),
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
entry = _LAZY_IMPORTS.get(name)
|
||||
if entry is not None:
|
||||
module_path, attr = entry
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path, package=__name__)
|
||||
value = getattr(module, attr)
|
||||
globals()[name] = value
|
||||
return value
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -143,71 +143,6 @@ To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||
tool_arguments.
|
||||
6. Output: `result` (the tool's return value) and `error` (error message)
|
||||
|
||||
### Using SmartDecisionMakerBlock (AI Orchestrator with Agent Mode)
|
||||
|
||||
To create an agent where AI autonomously decides which tools or sub-agents to
|
||||
call in a loop until the task is complete:
|
||||
1. Create a `SmartDecisionMakerBlock` node
|
||||
(ID: `3b191d9f-356f-482d-8238-ba04b6d18381`)
|
||||
2. Set `input_default`:
|
||||
- `agent_mode_max_iterations`: Choose based on task complexity:
|
||||
- `1` for single-step tool calls (AI picks one tool, calls it, done)
|
||||
- `3`–`10` for multi-step tasks (AI calls tools iteratively)
|
||||
- `-1` for open-ended orchestration (AI loops until it decides it's done).
|
||||
**Use with caution** — prefer bounded iterations (3–10) unless
|
||||
genuinely needed, as unbounded loops risk runaway cost and execution.
|
||||
Do NOT use `0` (traditional mode) — it requires complex external
|
||||
conversation-history loop wiring that the agent generator does not
|
||||
produce.
|
||||
- `conversation_compaction`: `true` (recommended to avoid context overflow)
|
||||
- `retry`: Number of retries on tool-call failure (default `3`).
|
||||
Set to `0` to disable retries.
|
||||
- `multiple_tool_calls`: Whether the AI can invoke multiple tools in a
|
||||
single turn (default `false`). Enable when tools are independent and
|
||||
can run concurrently.
|
||||
- Optional: `sys_prompt` for extra LLM context about how to orchestrate
|
||||
3. Wire the `prompt` input from an `AgentInputBlock` (the user's task)
|
||||
4. Create downstream tool blocks — regular blocks **or** `AgentExecutorBlock`
|
||||
nodes that call sub-agents
|
||||
5. Link each tool to the SmartDecisionMaker: set `source_name: "tools"` on
|
||||
the SmartDecisionMaker side and `sink_name: <input_field>` on each tool
|
||||
block's input. Create one link per input field the tool needs.
|
||||
6. Wire the `finished` output to an `AgentOutputBlock` for the final result
|
||||
7. Credentials (LLM API key) are configured by the user in the platform UI
|
||||
after saving — do NOT require them upfront
|
||||
|
||||
**Example — Orchestrator calling two sub-agents:**
|
||||
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
|
||||
- Node 2: `SmartDecisionMakerBlock` (input_default:
|
||||
`{"agent_mode_max_iterations": 10, "conversation_compaction": true}`)
|
||||
- Node 3: `AgentExecutorBlock` (sub-agent A — set `graph_id`, `graph_version`,
|
||||
`input_schema`, `output_schema` from library agent)
|
||||
- Node 4: `AgentExecutorBlock` (sub-agent B — same pattern)
|
||||
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
|
||||
- Links:
|
||||
- Input→SDM: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- SDM→Agent A (per input field): `source_name: "tools"`,
|
||||
`sink_name: "<agent_a_input_field>"`
|
||||
- SDM→Agent B (per input field): `source_name: "tools"`,
|
||||
`sink_name: "<agent_b_input_field>"`
|
||||
- SDM→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
|
||||
**Example — Orchestrator calling regular blocks as tools:**
|
||||
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
|
||||
- Node 2: `SmartDecisionMakerBlock` (input_default:
|
||||
`{"agent_mode_max_iterations": 5, "conversation_compaction": true}`)
|
||||
- Node 3: `GetWebpageBlock` (regular block — the AI calls it as a tool)
|
||||
- Node 4: `AITextGeneratorBlock` (another regular block as a tool)
|
||||
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
|
||||
- Links:
|
||||
- Input→SDM: `source_name: "result"`, `sink_name: "prompt"`
|
||||
- SDM→GetWebpage: `source_name: "tools"`, `sink_name: "url"`
|
||||
- SDM→AITextGenerator: `source_name: "tools"`, `sink_name: "prompt"`
|
||||
- SDM→Output: `source_name: "finished"`, `sink_name: "value"`
|
||||
|
||||
Regular blocks work exactly like sub-agents as tools — wire each input
|
||||
field from `source_name: "tools"` on the SmartDecisionMaker side.
|
||||
|
||||
### Example: Simple AI Text Processor
|
||||
|
||||
A minimal agent with input, processing, and output:
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
"""Public helpers for consuming a copilot stream as a simple request-response.
|
||||
|
||||
This module exposes :class:`CopilotResult` and :func:`collect_copilot_response`
|
||||
so that callers (e.g. the AutoPilot block) can consume the copilot stream
|
||||
without implementing their own event loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class CopilotResult:
|
||||
"""Aggregated result from consuming a copilot stream.
|
||||
|
||||
Returned by :func:`collect_copilot_response` so callers don't need to
|
||||
implement their own event-loop over the raw stream events.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"response_text",
|
||||
"tool_calls",
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.response_text: str = ""
|
||||
self.tool_calls: list[dict[str, Any]] = []
|
||||
self.prompt_tokens: int = 0
|
||||
self.completion_tokens: int = 0
|
||||
self.total_tokens: int = 0
|
||||
|
||||
|
||||
async def collect_copilot_response(
|
||||
*,
|
||||
session_id: str,
|
||||
message: str,
|
||||
user_id: str,
|
||||
is_user_message: bool = True,
|
||||
) -> CopilotResult:
|
||||
"""Consume :func:`stream_chat_completion_sdk` and return aggregated results.
|
||||
|
||||
This is the recommended entry-point for callers that need a simple
|
||||
request-response interface (e.g. the AutoPilot block) rather than
|
||||
streaming individual events. It avoids duplicating the event-collection
|
||||
logic and does NOT wrap the stream in ``asyncio.timeout`` — the SDK
|
||||
manages its own heartbeat-based timeouts internally.
|
||||
|
||||
Args:
|
||||
session_id: Chat session to use.
|
||||
message: The user message / prompt.
|
||||
user_id: Authenticated user ID.
|
||||
is_user_message: Whether this is a user-initiated message.
|
||||
|
||||
Returns:
|
||||
A :class:`CopilotResult` with the aggregated response text,
|
||||
tool calls, and token usage.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the stream yields a ``StreamError`` event.
|
||||
"""
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
from .service import stream_chat_completion_sdk
|
||||
|
||||
result = CopilotResult()
|
||||
response_parts: list[str] = []
|
||||
tool_calls_by_id: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async for event in stream_chat_completion_sdk(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
):
|
||||
if isinstance(event, StreamTextDelta):
|
||||
response_parts.append(event.delta)
|
||||
elif isinstance(event, StreamToolInputAvailable):
|
||||
entry: dict[str, Any] = {
|
||||
"tool_call_id": event.toolCallId,
|
||||
"tool_name": event.toolName,
|
||||
"input": event.input,
|
||||
"output": None,
|
||||
"success": None,
|
||||
}
|
||||
result.tool_calls.append(entry)
|
||||
tool_calls_by_id[event.toolCallId] = entry
|
||||
elif isinstance(event, StreamToolOutputAvailable):
|
||||
if tc := tool_calls_by_id.get(event.toolCallId):
|
||||
tc["output"] = event.output
|
||||
tc["success"] = event.success
|
||||
elif isinstance(event, StreamUsage):
|
||||
result.prompt_tokens += event.prompt_tokens
|
||||
result.completion_tokens += event.completion_tokens
|
||||
result.total_tokens += event.total_tokens
|
||||
elif isinstance(event, StreamError):
|
||||
raise RuntimeError(f"Copilot error: {event.errorText}")
|
||||
|
||||
result.response_text = "".join(response_parts)
|
||||
return result
|
||||
@@ -11,8 +11,7 @@ persistence, and the ``CompactionTracker`` state machine.
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
@@ -28,19 +27,6 @@ from ..response_model import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompactionResult:
|
||||
"""Result of emit_end_if_ready — bundles events with compaction metadata.
|
||||
|
||||
Eliminates the need for separate ``compaction_just_ended`` checks,
|
||||
preventing TOCTOU races between the emit call and the flag read.
|
||||
"""
|
||||
|
||||
events: list[StreamBaseResponse] = field(default_factory=list)
|
||||
just_ended: bool = False
|
||||
transcript_path: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event builders (private — use CompactionTracker or compaction_events)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -120,12 +106,14 @@ def filter_compaction_messages(
|
||||
filtered: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
real_calls: list[dict[str, Any]] = []
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
|
||||
compaction_ids.add(tc.get("id", ""))
|
||||
else:
|
||||
real_calls.append(tc)
|
||||
real_calls = [
|
||||
tc
|
||||
for tc in msg.tool_calls
|
||||
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
|
||||
]
|
||||
if not real_calls and not msg.content:
|
||||
continue
|
||||
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
|
||||
@@ -189,22 +177,11 @@ class CompactionTracker:
|
||||
self._start_emitted = False
|
||||
self._done = False
|
||||
self._tool_call_id = ""
|
||||
self._transcript_path: str = ""
|
||||
|
||||
def on_compact(self, transcript_path: str = "") -> None:
|
||||
"""Callback for the PreCompact hook. Stores transcript_path."""
|
||||
if (
|
||||
self._transcript_path
|
||||
and transcript_path
|
||||
and self._transcript_path != transcript_path
|
||||
):
|
||||
logger.warning(
|
||||
"[Compaction] Overwriting transcript_path %s -> %s",
|
||||
self._transcript_path,
|
||||
transcript_path,
|
||||
)
|
||||
self._transcript_path = transcript_path
|
||||
self._compact_start.set()
|
||||
@property
|
||||
def on_compact(self) -> Callable[[], None]:
|
||||
"""Callback for the PreCompact hook."""
|
||||
return self._compact_start.set
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pre-query compaction
|
||||
@@ -221,11 +198,9 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
self._transcript_path = ""
|
||||
|
||||
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
|
||||
"""If the PreCompact hook fired, emit start events (spinning tool)."""
|
||||
@@ -236,20 +211,15 @@ class CompactionTracker:
|
||||
return _start_events(self._tool_call_id)
|
||||
return []
|
||||
|
||||
async def emit_end_if_ready(self, session: ChatSession) -> CompactionResult:
|
||||
"""If compaction is in progress, emit end events and persist.
|
||||
|
||||
Returns a ``CompactionResult`` with ``just_ended=True`` and the
|
||||
captured ``transcript_path`` when a compaction cycle completes.
|
||||
This avoids a separate flag check (TOCTOU-safe).
|
||||
"""
|
||||
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
|
||||
"""If compaction is in progress, emit end events and persist."""
|
||||
# Yield so pending hook tasks can set compact_start
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if self._done:
|
||||
return CompactionResult()
|
||||
return []
|
||||
if not self._start_emitted and not self._compact_start.is_set():
|
||||
return CompactionResult()
|
||||
return []
|
||||
|
||||
if self._start_emitted:
|
||||
# Close the open spinner
|
||||
@@ -262,12 +232,8 @@ class CompactionTracker:
|
||||
COMPACTION_DONE_MSG, tool_call_id=persist_id
|
||||
)
|
||||
|
||||
transcript_path = self._transcript_path
|
||||
self._compact_start.clear()
|
||||
self._start_emitted = False
|
||||
self._done = True
|
||||
self._transcript_path = ""
|
||||
_persist(session, persist_id, COMPACTION_DONE_MSG)
|
||||
return CompactionResult(
|
||||
events=done_events, just_ended=True, transcript_path=transcript_path
|
||||
)
|
||||
return done_events
|
||||
|
||||
@@ -195,11 +195,10 @@ class TestCompactionTracker:
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert len(result.events) == 2
|
||||
assert isinstance(result.events[0], StreamToolOutputAvailable)
|
||||
assert isinstance(result.events[1], StreamFinishStep)
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert len(evts) == 2
|
||||
assert isinstance(evts[0], StreamToolOutputAvailable)
|
||||
assert isinstance(evts[1], StreamFinishStep)
|
||||
# Should persist
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@@ -211,32 +210,28 @@ class TestCompactionTracker:
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
# Don't call emit_start_if_ready
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert len(result.events) == 5 # Full self-contained event
|
||||
assert isinstance(result.events[0], StreamStartStep)
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert len(evts) == 5 # Full self-contained event
|
||||
assert isinstance(evts[0], StreamStartStep)
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_no_op_when_no_new_compaction(self):
|
||||
async def test_emit_end_no_op_when_done(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = await tracker.emit_end_if_ready(session)
|
||||
assert result1.just_ended is True
|
||||
# Second call should be no-op (no new on_compact)
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is False
|
||||
assert result2.events == []
|
||||
await tracker.emit_end_if_ready(session)
|
||||
# Second call should be no-op
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert evts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_no_op_when_nothing_happened(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is False
|
||||
assert result.events == []
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert evts == []
|
||||
|
||||
def test_emit_pre_query(self):
|
||||
tracker = CompactionTracker()
|
||||
@@ -251,29 +246,20 @@ class TestCompactionTracker:
|
||||
tracker._done = True
|
||||
tracker._start_emitted = True
|
||||
tracker._tool_call_id = "old"
|
||||
tracker._transcript_path = "/some/path"
|
||||
tracker.reset_for_query()
|
||||
assert tracker._done is False
|
||||
assert tracker._start_emitted is False
|
||||
assert tracker._tool_call_id == ""
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_query_blocks_sdk_compaction_until_reset(self):
|
||||
"""After pre-query compaction, SDK compaction is blocked until
|
||||
reset_for_query is called."""
|
||||
async def test_pre_query_blocks_sdk_compaction(self):
|
||||
"""After pre-query compaction, SDK compaction events are suppressed."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.emit_pre_query(session)
|
||||
tracker.on_compact()
|
||||
# _done is True so emit_start_if_ready is blocked
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert evts == []
|
||||
# Reset clears _done, allowing subsequent compaction
|
||||
tracker.reset_for_query()
|
||||
tracker.on_compact()
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert len(evts) == 3
|
||||
assert evts == [] # _done blocks it
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_allows_new_compaction(self):
|
||||
@@ -293,9 +279,9 @@ class TestCompactionTracker:
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
end_evts = await tracker.emit_end_if_ready(session)
|
||||
start_evt = start_evts[1]
|
||||
end_evt = result.events[0]
|
||||
end_evt = end_evts[0]
|
||||
assert isinstance(start_evt, StreamToolInputStart)
|
||||
assert isinstance(end_evt, StreamToolOutputAvailable)
|
||||
assert start_evt.toolCallId == end_evt.toolCallId
|
||||
@@ -303,105 +289,3 @@ class TestCompactionTracker:
|
||||
tool_calls = session.messages[0].tool_calls
|
||||
assert tool_calls is not None
|
||||
assert tool_calls[0]["id"] == start_evt.toolCallId
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compactions_within_query(self):
|
||||
"""Two mid-stream compactions within a single query both trigger."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
# First compaction cycle
|
||||
tracker.on_compact("/path/1")
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = await tracker.emit_end_if_ready(session)
|
||||
assert result1.just_ended is True
|
||||
assert len(result1.events) == 2
|
||||
assert result1.transcript_path == "/path/1"
|
||||
|
||||
# Second compaction cycle (should NOT be blocked — _done resets
|
||||
# because emit_end_if_ready sets it True, but the next on_compact
|
||||
# + emit_start_if_ready checks !_done which IS True now.
|
||||
# So we need reset_for_query between queries, but within a single
|
||||
# query multiple compactions work because _done blocks emit_start
|
||||
# until the next message arrives, at which point emit_end detects it)
|
||||
#
|
||||
# Actually: _done=True blocks emit_start_if_ready, so we need
|
||||
# the stream loop to reset. In practice service.py doesn't call
|
||||
# reset between compactions within the same query — let's verify
|
||||
# the actual behavior.
|
||||
tracker.on_compact("/path/2")
|
||||
# _done is True from first compaction, so start is blocked
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
assert start_evts == []
|
||||
# But emit_end returns no-op because _done is True
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compactions_with_intervening_message(self):
|
||||
"""Multiple compactions work when the stream loop processes messages between them.
|
||||
|
||||
In the real service.py flow:
|
||||
1. PreCompact fires → on_compact()
|
||||
2. emit_start shows spinner
|
||||
3. Next message arrives → emit_end completes compaction (_done=True)
|
||||
4. Stream continues processing messages...
|
||||
5. If a second PreCompact fires, _done=True blocks emit_start
|
||||
6. But the next message triggers emit_end, which sees _done=True → no-op
|
||||
7. The stream loop needs to detect this and handle accordingly
|
||||
|
||||
The actual flow for multiple compactions within a query requires
|
||||
_done to be cleared between them. The service.py code uses
|
||||
CompactionResult.just_ended to trigger replace_entries, and _done
|
||||
stays True until reset_for_query.
|
||||
"""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
# First compaction
|
||||
tracker.on_compact("/path/1")
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = await tracker.emit_end_if_ready(session)
|
||||
assert result1.just_ended is True
|
||||
assert result1.transcript_path == "/path/1"
|
||||
|
||||
# Simulate reset between queries
|
||||
tracker.reset_for_query()
|
||||
|
||||
# Second compaction in new query
|
||||
tracker.on_compact("/path/2")
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
assert len(start_evts) == 3
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is True
|
||||
assert result2.transcript_path == "/path/2"
|
||||
|
||||
def test_on_compact_stores_transcript_path(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact("/some/path.jsonl")
|
||||
assert tracker._transcript_path == "/some/path.jsonl"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_returns_transcript_path(self):
|
||||
"""CompactionResult includes the transcript_path from on_compact."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact("/my/session.jsonl")
|
||||
tracker.emit_start_if_ready()
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert result.transcript_path == "/my/session.jsonl"
|
||||
# transcript_path is cleared after emit_end
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_clears_transcript_path(self):
|
||||
"""After emit_end, _transcript_path is reset so it doesn't leak to
|
||||
subsequent non-compaction emit_end calls."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact("/first/path.jsonl")
|
||||
tracker.emit_start_if_ready()
|
||||
await tracker.emit_end_if_ready(session)
|
||||
# After compaction, _transcript_path is cleared
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
"""Shared test fixtures for copilot SDK tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_chat_config():
|
||||
"""Mock ChatConfig so compact_transcript tests skip real config lookup."""
|
||||
with patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type("Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"})(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
|
||||
"""Build a minimal valid JSONL transcript from (role, content) pairs.
|
||||
|
||||
Use this helper in any copilot SDK test that needs a well-formed
|
||||
transcript without hitting the real storage layer.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for role, content in pairs:
|
||||
uid = str(uuid4())
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if role == "assistant":
|
||||
msg.update(
|
||||
{
|
||||
"model": "",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": msg,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
@@ -1,17 +1,9 @@
|
||||
"""Dummy SDK service for testing copilot streaming.
|
||||
|
||||
Returns mock streaming responses without calling Claude Agent SDK.
|
||||
Enable via CHAT_TEST_MODE=true in .env (ChatConfig.test_mode).
|
||||
Enable via COPILOT_TEST_MODE=true environment variable.
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
|
||||
Magic keywords (case-insensitive, anywhere in message):
|
||||
__test_transient_error__ — Simulate a transient Anthropic API error
|
||||
(ECONNRESET). Streams partial text, then
|
||||
yields StreamError with retryable prefix.
|
||||
__test_fatal_error__ — Simulate a non-retryable SDK error.
|
||||
__test_slow_response__ — Simulate a slow response (2s per word).
|
||||
(no keyword) — Normal dummy response.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -20,39 +12,12 @@ import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
FRIENDLY_TRANSIENT_MSG,
|
||||
)
|
||||
from ..model import ChatMessage, ChatSession, get_chat_session, upsert_chat_session
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
)
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _safe_upsert(session: ChatSession) -> None:
|
||||
"""Best-effort session persist — skip silently if DB is unavailable."""
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception:
|
||||
logger.debug("[TEST MODE] Could not persist session (DB unavailable)")
|
||||
|
||||
|
||||
def _has_keyword(message: str | None, keyword: str) -> bool:
|
||||
return keyword in (message or "").lower()
|
||||
|
||||
|
||||
async def stream_chat_completion_dummy(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -71,89 +36,24 @@ async def stream_chat_completion_dummy(
|
||||
- No timeout occurs
|
||||
- Text arrives in chunks
|
||||
- StreamFinish is sent by mark_session_completed
|
||||
|
||||
See module docstring for magic keywords that trigger error scenarios.
|
||||
"""
|
||||
logger.warning(
|
||||
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
|
||||
)
|
||||
|
||||
# Load session from DB (matches SDK service behaviour) so error markers
|
||||
# and the assistant reply are persisted and survive page refresh.
|
||||
# Best-effort: skip if DB is unavailable (e.g. unit tests).
|
||||
if session is None:
|
||||
try:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
except Exception:
|
||||
logger.debug("[TEST MODE] Could not load session (DB unavailable)")
|
||||
session = None
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Start the stream (matches baseline: StreamStart → StreamStartStep)
|
||||
# Start the stream
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
yield StreamStartStep()
|
||||
|
||||
# --- Magic keyword: transient error (retryable) -------------------------
|
||||
if _has_keyword(message, "__test_transient_error__"):
|
||||
# Stream some partial text first (simulates mid-stream failure)
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
for word in ["Working", "on", "it..."]:
|
||||
yield StreamTextDelta(id=text_block_id, delta=f"{word} ")
|
||||
await asyncio.sleep(0.1)
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
yield StreamFinishStep()
|
||||
# Persist retryable marker so "Try Again" button shows after refresh
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} {FRIENDLY_TRANSIENT_MSG}",
|
||||
)
|
||||
)
|
||||
await _safe_upsert(session)
|
||||
yield StreamError(
|
||||
errorText=FRIENDLY_TRANSIENT_MSG,
|
||||
code="transient_api_error",
|
||||
)
|
||||
return
|
||||
|
||||
# --- Magic keyword: fatal error (non-retryable) -------------------------
|
||||
if _has_keyword(message, "__test_fatal_error__"):
|
||||
yield StreamFinishStep()
|
||||
error_msg = "Internal SDK error: model refused to respond"
|
||||
# Persist non-retryable error marker
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_ERROR_PREFIX} {error_msg}",
|
||||
)
|
||||
)
|
||||
await _safe_upsert(session)
|
||||
yield StreamError(errorText=error_msg, code="sdk_error")
|
||||
return
|
||||
|
||||
# --- Magic keyword: slow response ---------------------------------------
|
||||
delay = 2.0 if _has_keyword(message, "__test_slow_response__") else 0.1
|
||||
|
||||
# --- Normal dummy response ----------------------------------------------
|
||||
# Simulate streaming text response with delays
|
||||
dummy_response = "I counted: 1... 2... 3. All done!"
|
||||
words = dummy_response.split()
|
||||
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
for i, word in enumerate(words):
|
||||
# Add space except for last word
|
||||
text = word if i == len(words) - 1 else f"{word} "
|
||||
yield StreamTextDelta(id=text_block_id, delta=text)
|
||||
await asyncio.sleep(delay)
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
|
||||
# Persist the assistant reply so it survives page refresh
|
||||
if session:
|
||||
session.messages.append(ChatMessage(role="assistant", content=dummy_response))
|
||||
await _safe_upsert(session)
|
||||
|
||||
yield StreamFinishStep()
|
||||
yield StreamFinish()
|
||||
# Small delay to simulate real streaming
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -26,41 +26,6 @@ from backend.copilot.context import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _check_sandbox_symlink_escape(
|
||||
sandbox: Any,
|
||||
parent: str,
|
||||
) -> str | None:
|
||||
"""Resolve the canonical parent path inside the sandbox to detect symlink escapes.
|
||||
|
||||
``normpath`` (used by ``resolve_sandbox_path``) only normalises the string;
|
||||
``readlink -f`` follows actual symlinks on the sandbox filesystem.
|
||||
|
||||
Returns the canonical parent path, or ``None`` if the path escapes
|
||||
``E2B_WORKDIR``.
|
||||
|
||||
Note: There is an inherent TOCTOU window between this check and the
|
||||
subsequent ``sandbox.files.write()``. A symlink could theoretically be
|
||||
replaced between the two operations. This is acceptable in the E2B
|
||||
sandbox model since the sandbox is single-user and ephemeral.
|
||||
"""
|
||||
canonical_res = await sandbox.commands.run(
|
||||
f"readlink -f {shlex.quote(parent or E2B_WORKDIR)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=5,
|
||||
)
|
||||
canonical_parent = (canonical_res.stdout or "").strip()
|
||||
if (
|
||||
canonical_res.exit_code != 0
|
||||
or not canonical_parent
|
||||
or (
|
||||
canonical_parent != E2B_WORKDIR
|
||||
and not canonical_parent.startswith(E2B_WORKDIR + "/")
|
||||
)
|
||||
):
|
||||
return None
|
||||
return canonical_parent
|
||||
|
||||
|
||||
def _get_sandbox():
|
||||
return get_current_sandbox()
|
||||
|
||||
@@ -141,10 +106,6 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
await sandbox.files.make_dir(parent)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
@@ -169,12 +130,6 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
parent = os.path.dirname(remote)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
|
||||
@@ -4,19 +4,15 @@ Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
from .e2b_file_tools import (
|
||||
_check_sandbox_symlink_escape,
|
||||
_read_local,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
@@ -25,48 +21,46 @@ from .e2b_file_tools import (
|
||||
|
||||
class TestResolveSandboxPath:
|
||||
def test_relative_path_resolved(self):
|
||||
assert resolve_sandbox_path("src/main.py") == f"{E2B_WORKDIR}/src/main.py"
|
||||
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert (
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/file.txt") == f"{E2B_WORKDIR}/file.txt"
|
||||
)
|
||||
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert resolve_sandbox_path(E2B_WORKDIR) == E2B_WORKDIR
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
|
||||
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == f"{E2B_WORKDIR}/a/b/c/d/e.txt"
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert resolve_sandbox_path("src/") == f"{E2B_WORKDIR}/src"
|
||||
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within E2B_WORKDIR is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -79,13 +73,9 @@ class TestResolveSandboxPath:
|
||||
|
||||
|
||||
class TestReadLocal:
|
||||
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
|
||||
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
|
||||
"""Create a tool-results file under <encoded>/<uuid>/tool-results/."""
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
@@ -117,9 +107,7 @@ class TestReadLocal:
|
||||
def test_read_nonexistent_tool_results(self):
|
||||
"""A tool-results path that doesn't exist returns FileNotFoundError."""
|
||||
encoded = "-tmp-copilot-e2b-test-nofile"
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
@@ -129,7 +117,7 @@ class TestReadLocal:
|
||||
assert "not found" in result["content"][0]["text"].lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
os.rmdir(tool_results_dir)
|
||||
|
||||
def test_read_traversal_path_blocked(self):
|
||||
"""A traversal attempt that escapes allowed directories is blocked."""
|
||||
@@ -164,66 +152,3 @@ class TestReadLocal:
|
||||
"""Without _current_project_dir set, all paths are blocked."""
|
||||
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_sandbox_symlink_escape — symlink escape detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_sandbox(stdout: str, exit_code: int = 0) -> SimpleNamespace:
|
||||
"""Build a minimal sandbox mock whose commands.run returns a fixed result."""
|
||||
run_result = SimpleNamespace(stdout=stdout, exit_code=exit_code)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
return SimpleNamespace(commands=commands)
|
||||
|
||||
|
||||
class TestCheckSandboxSymlinkEscape:
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonical_path_within_workdir_returns_path(self):
|
||||
"""When readlink -f resolves to a path inside E2B_WORKDIR, returns it."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/src\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result == f"{E2B_WORKDIR}/src"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workdir_itself_returns_workdir(self):
|
||||
"""When readlink -f resolves to E2B_WORKDIR exactly, returns E2B_WORKDIR."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, E2B_WORKDIR)
|
||||
assert result == E2B_WORKDIR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_symlink_escape_returns_none(self):
|
||||
"""When readlink -f resolves outside E2B_WORKDIR (symlink escape), returns None."""
|
||||
sandbox = _make_sandbox(stdout="/etc\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonzero_exit_code_returns_none(self):
|
||||
"""A non-zero exit code from readlink -f returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=1)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_stdout_returns_none(self):
|
||||
"""Empty stdout from readlink (e.g. path doesn't exist yet) returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefix_collision_returns_none(self):
|
||||
"""A path prefixed with E2B_WORKDIR but not within it is rejected."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}-evil\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}-evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deeply_nested_path_within_workdir(self):
|
||||
"""Deep nested paths inside E2B_WORKDIR are allowed."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
|
||||
assert result == f"{E2B_WORKDIR}/a/b/c/d"
|
||||
|
||||
@@ -1,531 +0,0 @@
|
||||
"""End-to-end compaction flow test.
|
||||
|
||||
Simulates the full service.py compaction lifecycle using real-format
|
||||
JSONL session files — no SDK subprocess needed. Exercises:
|
||||
|
||||
1. TranscriptBuilder loads a "downloaded" transcript
|
||||
2. User query appended, assistant response streamed
|
||||
3. PreCompact hook fires → CompactionTracker.on_compact()
|
||||
4. Next message → emit_start_if_ready() yields spinner events
|
||||
5. Message after that → emit_end_if_ready() returns CompactionResult
|
||||
6. read_compacted_entries() reads the CLI session file
|
||||
7. TranscriptBuilder.replace_entries() syncs state
|
||||
8. More messages appended post-compaction
|
||||
9. to_jsonl() exports full state for upload
|
||||
10. Fresh builder loads the export — roundtrip verified
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import (
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.compaction import CompactionTracker
|
||||
from backend.copilot.sdk.transcript import (
|
||||
read_compacted_entries,
|
||||
strip_progress_entries,
|
||||
)
|
||||
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run an async coroutine synchronously."""
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic CLI session file content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Pre-compaction conversation
|
||||
USER_1 = {
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "What files are in this project?"},
|
||||
}
|
||||
ASST_1_THINKING = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-think",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
ASST_1_TOOL = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-tool",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"},
|
||||
}
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
TOOL_RESULT_1 = {
|
||||
"type": "user",
|
||||
"uuid": "tr1",
|
||||
"parentUuid": "a1-tool",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "tu1",
|
||||
"content": "file1.py\nfile2.py",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
ASST_1_TEXT = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-text",
|
||||
"parentUuid": "tr1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_bbb",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
# Progress entries (should be stripped during upload)
|
||||
PROGRESS_1 = {
|
||||
"type": "progress",
|
||||
"uuid": "prog1",
|
||||
"parentUuid": "a1-tool",
|
||||
"data": {"type": "bash_progress", "stdout": "running ls..."},
|
||||
}
|
||||
# Second user message
|
||||
USER_2 = {
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"parentUuid": "a1-text",
|
||||
"message": {"role": "user", "content": "Show me file1.py"},
|
||||
}
|
||||
ASST_2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "u2",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ccc",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file1.py content..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Compaction summary (written by CLI after context compaction) ---
|
||||
COMPACT_SUMMARY = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Summary: User asked about project files. Found file1.py and file2.py. "
|
||||
"User then asked to see file1.py."
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction assistant response
|
||||
POST_COMPACT_ASST = {
|
||||
"type": "assistant",
|
||||
"uuid": "a3",
|
||||
"parentUuid": "cs1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ddd",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction user follow-up
|
||||
USER_3 = {
|
||||
"type": "user",
|
||||
"uuid": "u3",
|
||||
"parentUuid": "a3",
|
||||
"message": {"role": "user", "content": "Now show file2.py"},
|
||||
}
|
||||
ASST_3 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a4",
|
||||
"parentUuid": "u3",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_eee",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file2.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionE2E:
|
||||
def _write_session_file(self, session_dir, entries):
|
||||
"""Write a CLI session JSONL file."""
|
||||
path = session_dir / "session.jsonl"
|
||||
path.write_text(_make_jsonl(*entries))
|
||||
return path
|
||||
|
||||
def test_full_compaction_lifecycle(self, tmp_path, monkeypatch):
|
||||
"""Simulate the complete service.py compaction flow.
|
||||
|
||||
Timeline:
|
||||
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
|
||||
2. Current turn: download → load_previous
|
||||
3. User sends "Now show file2.py" → append_user
|
||||
4. SDK starts streaming response
|
||||
5. Mid-stream: PreCompact hook fires (context too large)
|
||||
6. CLI writes compaction summary to session file
|
||||
7. Next SDK message → emit_start (spinner)
|
||||
8. Following message → emit_end (CompactionResult)
|
||||
9. read_compacted_entries reads the session file
|
||||
10. replace_entries syncs TranscriptBuilder
|
||||
11. More assistant messages appended
|
||||
12. Export → upload → next turn downloads it
|
||||
"""
|
||||
# --- Setup CLI projects directory ---
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
|
||||
previous_transcript = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(previous_transcript)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
# --- Step 3: User sends new query ---
|
||||
builder.append_user("Now show file2.py")
|
||||
assert builder.entry_count == 8
|
||||
|
||||
# --- Step 4: SDK starts streaming ---
|
||||
builder.append_assistant(
|
||||
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
assert builder.entry_count == 9
|
||||
|
||||
# --- Step 5-6: PreCompact fires, CLI writes session file ---
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
USER_3,
|
||||
ASST_3,
|
||||
],
|
||||
)
|
||||
|
||||
# --- Step 7: CompactionTracker receives PreCompact hook ---
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test-user")
|
||||
tracker.on_compact(str(session_file))
|
||||
|
||||
# --- Step 8: Next SDK message arrives → emit_start ---
|
||||
start_events = tracker.emit_start_if_ready()
|
||||
assert len(start_events) == 3
|
||||
assert isinstance(start_events[0], StreamStartStep)
|
||||
assert isinstance(start_events[1], StreamToolInputStart)
|
||||
assert isinstance(start_events[2], StreamToolInputAvailable)
|
||||
|
||||
# Verify tool_call_id is set
|
||||
tool_call_id = start_events[1].toolCallId
|
||||
assert tool_call_id.startswith("compaction-")
|
||||
|
||||
# --- Step 9: Following message → emit_end ---
|
||||
result = _run(tracker.emit_end_if_ready(session))
|
||||
assert result.just_ended is True
|
||||
assert result.transcript_path == str(session_file)
|
||||
assert len(result.events) == 2
|
||||
assert isinstance(result.events[0], StreamToolOutputAvailable)
|
||||
assert isinstance(result.events[1], StreamFinishStep)
|
||||
# Verify same tool_call_id
|
||||
assert result.events[0].toolCallId == tool_call_id
|
||||
|
||||
# Session should have compaction messages persisted
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[0].role == "assistant"
|
||||
assert session.messages[1].role == "tool"
|
||||
|
||||
# --- Step 10: read_compacted_entries + replace_entries ---
|
||||
compacted = read_compacted_entries(str(session_file))
|
||||
assert compacted is not None
|
||||
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
|
||||
assert len(compacted) == 4
|
||||
assert compacted[0]["uuid"] == "cs1"
|
||||
assert compacted[0]["isCompactSummary"] is True
|
||||
|
||||
# Replace builder state with compacted entries
|
||||
old_count = builder.entry_count
|
||||
builder.replace_entries(compacted)
|
||||
assert builder.entry_count == 4 # Only compacted entries
|
||||
assert builder.entry_count < old_count # Compaction reduced entries
|
||||
|
||||
# --- Step 11: More assistant messages after compaction ---
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
assert builder.entry_count == 5
|
||||
|
||||
# --- Step 12: Export for upload ---
|
||||
output = builder.to_jsonl()
|
||||
assert output # Not empty
|
||||
output_entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert len(output_entries) == 5
|
||||
|
||||
# Verify structure:
|
||||
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
|
||||
assert output_entries[0]["type"] == "summary"
|
||||
assert output_entries[0].get("isCompactSummary") is True
|
||||
assert output_entries[0]["uuid"] == "cs1"
|
||||
assert output_entries[1]["uuid"] == "a3"
|
||||
assert output_entries[2]["uuid"] == "u3"
|
||||
assert output_entries[3]["uuid"] == "a4"
|
||||
assert output_entries[4]["type"] == "assistant"
|
||||
|
||||
# Verify parent chain is intact
|
||||
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
|
||||
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
|
||||
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
|
||||
assert output_entries[4]["parentUuid"] == "a4" # new → a4
|
||||
|
||||
# --- Step 13: Roundtrip — next turn loads this export ---
|
||||
builder2 = TranscriptBuilder()
|
||||
builder2.load_previous(output)
|
||||
assert builder2.entry_count == 5
|
||||
|
||||
# isCompactSummary survives roundtrip
|
||||
output2 = builder2.to_jsonl()
|
||||
first_entry = json.loads(output2.strip().split("\n")[0])
|
||||
assert first_entry.get("isCompactSummary") is True
|
||||
|
||||
# Can append more messages
|
||||
builder2.append_user("What about file3.py?")
|
||||
assert builder2.entry_count == 6
|
||||
final_output = builder2.to_jsonl()
|
||||
last_entry = json.loads(final_output.strip().split("\n")[-1])
|
||||
assert last_entry["type"] == "user"
|
||||
# Parented to the last entry from previous turn
|
||||
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
|
||||
|
||||
def test_double_compaction_within_session(self, tmp_path, monkeypatch):
|
||||
"""Two compactions in the same session (across reset_for_query)."""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test")
|
||||
builder = TranscriptBuilder()
|
||||
|
||||
# --- First query with compaction ---
|
||||
builder.append_user("first question")
|
||||
builder.append_assistant([{"type": "text", "text": "first answer"}])
|
||||
|
||||
# Write session file for first compaction
|
||||
first_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-first",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "First compaction summary"},
|
||||
}
|
||||
first_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-first",
|
||||
"parentUuid": "cs-first",
|
||||
"message": {"role": "assistant", "content": "first post-compact"},
|
||||
}
|
||||
file1 = session_dir / "session1.jsonl"
|
||||
file1.write_text(_make_jsonl(first_summary, first_post))
|
||||
|
||||
tracker.on_compact(str(file1))
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = _run(tracker.emit_end_if_ready(session))
|
||||
assert result1.just_ended is True
|
||||
|
||||
compacted1 = read_compacted_entries(str(file1))
|
||||
assert compacted1 is not None
|
||||
builder.replace_entries(compacted1)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# --- Reset for second query ---
|
||||
tracker.reset_for_query()
|
||||
|
||||
# --- Second query with compaction ---
|
||||
builder.append_user("second question")
|
||||
builder.append_assistant([{"type": "text", "text": "second answer"}])
|
||||
|
||||
second_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-second",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Second compaction summary"},
|
||||
}
|
||||
second_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-second",
|
||||
"parentUuid": "cs-second",
|
||||
"message": {"role": "assistant", "content": "second post-compact"},
|
||||
}
|
||||
file2 = session_dir / "session2.jsonl"
|
||||
file2.write_text(_make_jsonl(second_summary, second_post))
|
||||
|
||||
tracker.on_compact(str(file2))
|
||||
tracker.emit_start_if_ready()
|
||||
result2 = _run(tracker.emit_end_if_ready(session))
|
||||
assert result2.just_ended is True
|
||||
|
||||
compacted2 = read_compacted_entries(str(file2))
|
||||
assert compacted2 is not None
|
||||
builder.replace_entries(compacted2)
|
||||
assert builder.entry_count == 2 # Only second compaction entries
|
||||
|
||||
# Export and verify
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert entries[0]["uuid"] == "cs-second"
|
||||
assert entries[0].get("isCompactSummary") is True
|
||||
|
||||
def test_strip_progress_then_load_then_compact_roundtrip(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""Full pipeline: strip → load → compact → replace → export → reload.
|
||||
|
||||
This tests the exact sequence that happens across two turns:
|
||||
Turn 1: SDK produces transcript with progress entries
|
||||
Upload: strip_progress_entries removes progress, upload to cloud
|
||||
Turn 2: Download → load_previous → compaction fires → replace → export
|
||||
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
|
||||
"""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# --- Turn 1: SDK produces raw transcript ---
|
||||
raw_content = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
|
||||
# Strip progress for upload
|
||||
stripped = strip_progress_entries(raw_content)
|
||||
stripped_entries = [
|
||||
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
|
||||
]
|
||||
# Progress should be gone
|
||||
assert not any(e.get("type") == "progress" for e in stripped_entries)
|
||||
assert len(stripped_entries) == 7 # 8 - 1 progress
|
||||
|
||||
# --- Turn 2: Download stripped, load, compaction happens ---
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(stripped)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
builder.append_user("Now show file2.py")
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Reading file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
|
||||
# CLI writes session file with compaction
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
],
|
||||
)
|
||||
|
||||
compacted = read_compacted_entries(str(session_file))
|
||||
assert compacted is not None
|
||||
builder.replace_entries(compacted)
|
||||
|
||||
# Append post-compaction message
|
||||
builder.append_user("Thanks!")
|
||||
output = builder.to_jsonl()
|
||||
|
||||
# --- Turn 3: Fresh load of Turn 2 export ---
|
||||
builder3 = TranscriptBuilder()
|
||||
builder3.load_previous(output)
|
||||
# Should have: compact_summary + post_compact_asst + "Thanks!"
|
||||
assert builder3.entry_count == 3
|
||||
|
||||
# Compact summary survived the full pipeline
|
||||
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
|
||||
assert first.get("isCompactSummary") is True
|
||||
assert first["type"] == "summary"
|
||||
@@ -41,20 +41,12 @@ from typing import Any
|
||||
from backend.copilot.context import (
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
get_workspace_manager,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.workspace_files import get_manager
|
||||
from backend.util.file import parse_workspace_uri
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
MIME_TO_FORMAT,
|
||||
PARSE_EXCEPTIONS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class FileRefExpansionError(Exception):
|
||||
@@ -82,8 +74,6 @@ _FILE_REF_RE = re.compile(
|
||||
_MAX_EXPAND_CHARS = 200_000
|
||||
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||
# Maximum raw byte size for bare ref structured parsing (10 MB).
|
||||
_MAX_BARE_REF_BYTES = 10_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,11 +83,6 @@ class FileRef:
|
||||
end_line: int | None # 1-indexed, inclusive
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_file_ref(text: str) -> FileRef | None:
|
||||
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||
|
||||
@@ -119,6 +104,17 @@ def parse_file_ref(text: str) -> FileRef | None:
|
||||
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else len(lines)
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
@@ -134,47 +130,27 @@ async def read_file_bytes(
|
||||
if plain.startswith("workspace://"):
|
||||
if not user_id:
|
||||
raise ValueError("workspace:// file references require authentication")
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
ws = parse_workspace_uri(plain)
|
||||
try:
|
||||
data = await (
|
||||
return await (
|
||||
manager.read_file(ws.file_ref)
|
||||
if ws.is_path
|
||||
else manager.read_file_by_id(ws.file_ref)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except (PermissionError, OSError) as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
except (AttributeError, TypeError, RuntimeError) as exc:
|
||||
# AttributeError/TypeError: workspace manager returned an
|
||||
# unexpected type or interface; RuntimeError: async runtime issues.
|
||||
logger.warning("Unexpected error reading %s: %s", plain, exc)
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
# NOTE: Workspace API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
# Read with a one-byte overshoot to detect files that exceed the limit
|
||||
# without a separate os.path.getsize call (avoids TOCTOU race).
|
||||
with open(resolved, "rb") as fh:
|
||||
data = fh.read(_MAX_BARE_REF_BYTES + 1)
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
|
||||
f"limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
return fh.read()
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except OSError as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
@@ -186,33 +162,9 @@ async def read_file_bytes(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
) from exc
|
||||
try:
|
||||
data = bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
return bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except Exception as exc:
|
||||
# E2B SDK raises SandboxException subclasses (NotFoundException,
|
||||
# TimeoutException, NotEnoughSpaceException, etc.) which don't
|
||||
# inherit from standard exceptions. Import lazily to avoid a
|
||||
# hard dependency on e2b at module level.
|
||||
try:
|
||||
from e2b.exceptions import SandboxException # noqa: PLC0415
|
||||
|
||||
if isinstance(exc, SandboxException):
|
||||
raise ValueError(
|
||||
f"Failed to read from sandbox: {plain}: {exc}"
|
||||
) from exc
|
||||
except ImportError:
|
||||
pass
|
||||
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
|
||||
# so they surface as real bugs rather than being silently masked.
|
||||
raise
|
||||
# NOTE: E2B sandbox API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
@@ -226,13 +178,15 @@ async def resolve_file_ref(
|
||||
) -> str:
|
||||
"""Resolve a :class:`FileRef` to its text content."""
|
||||
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
|
||||
return _apply_line_range(
|
||||
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
|
||||
)
|
||||
|
||||
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
session: "ChatSession",
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
@@ -278,9 +232,6 @@ async def expand_file_refs_in_string(
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
# remaining == 0 means the budget was exactly exhausted by the
|
||||
# previous ref. The elif below (len > remaining) won't catch
|
||||
# this since 0 > 0 is false, so we need the <= 0 check.
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
@@ -301,31 +252,13 @@ async def expand_file_refs_in_string(
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
*,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
session: "ChatSession",
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||
|
||||
String values are expanded in-place. Nested dicts and lists are
|
||||
traversed. Non-string scalars are returned unchanged.
|
||||
|
||||
**Bare references** (the entire argument value is a single
|
||||
``@@agptfile:...`` token with no surrounding text) are resolved and then
|
||||
parsed according to the file's extension or MIME type. See
|
||||
:mod:`backend.util.file_content_parser` for the full list of supported
|
||||
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
|
||||
|
||||
When *input_schema* is provided and the target property has
|
||||
``"type": "string"``, structured parsing is skipped — the raw file content
|
||||
is returned as a plain string so blocks receive the original text.
|
||||
|
||||
If the format is unrecognised or parsing fails, the content is returned as
|
||||
a plain string (the fallback).
|
||||
|
||||
**Embedded references** (``@@agptfile:`` mixed with other text) always
|
||||
produce a plain string — structured parsing only applies to bare refs.
|
||||
|
||||
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||
so the tool is *not* executed with an error string as its input. The
|
||||
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||
@@ -334,382 +267,15 @@ async def expand_file_refs_in_args(
|
||||
if not args:
|
||||
return args
|
||||
|
||||
properties = (input_schema or {}).get("properties", {})
|
||||
|
||||
async def _expand(
|
||||
value: Any,
|
||||
*,
|
||||
prop_schema: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Recursively expand a single argument value.
|
||||
|
||||
Strings are checked for ``@@agptfile:`` references and expanded
|
||||
(bare refs get structured parsing; embedded refs get inline
|
||||
substitution). Dicts and lists are traversed recursively,
|
||||
threading the corresponding sub-schema from *prop_schema* so
|
||||
that nested fields also receive correct type-aware expansion.
|
||||
Non-string scalars pass through unchanged.
|
||||
"""
|
||||
async def _expand(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
ref = parse_file_ref(value)
|
||||
if ref is not None:
|
||||
# MediaFileType fields: return the raw URI immediately —
|
||||
# no file reading, no format inference, no content parsing.
|
||||
if _is_media_file_field(prop_schema):
|
||||
return ref.uri
|
||||
|
||||
fmt = infer_format_from_uri(ref.uri)
|
||||
# Workspace URIs by ID (workspace://abc123) have no extension.
|
||||
# When the MIME fragment is also missing, fall back to the
|
||||
# workspace file manager's metadata for format detection.
|
||||
if fmt is None and ref.uri.startswith("workspace://"):
|
||||
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
|
||||
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
|
||||
|
||||
# Not a bare ref — do normal inline expansion.
|
||||
return await expand_file_refs_in_string(
|
||||
value, user_id, session, raise_on_error=True
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
# When the schema says this is an object but doesn't define
|
||||
# inner properties, skip expansion — the caller (e.g.
|
||||
# RunBlockTool) will expand with the actual nested schema.
|
||||
if (
|
||||
prop_schema is not None
|
||||
and prop_schema.get("type") == "object"
|
||||
and "properties" not in prop_schema
|
||||
):
|
||||
return value
|
||||
nested_props = (prop_schema or {}).get("properties", {})
|
||||
return {
|
||||
k: await _expand(v, prop_schema=nested_props.get(k))
|
||||
for k, v in value.items()
|
||||
}
|
||||
return {k: await _expand(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
items_schema = (prop_schema or {}).get("items")
|
||||
return [await _expand(item, prop_schema=items_schema) for item in value]
|
||||
return [await _expand(item) for item in value]
|
||||
return value
|
||||
|
||||
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers (used by the public functions above)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive).
|
||||
|
||||
When the requested range extends beyond the file, a note is appended
|
||||
so the LLM knows it received the entire remaining content.
|
||||
"""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
total = len(lines)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else total
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
result = "".join(selected)
|
||||
if end is not None and end > total:
|
||||
result += f"\n[Note: file has only {total} lines]\n"
|
||||
return result
|
||||
|
||||
|
||||
def _to_str(content: str | bytes) -> str:
|
||||
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return content.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _check_content_size(content: str | bytes) -> None:
|
||||
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
|
||||
|
||||
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
|
||||
(``_expand_bare_ref``) can unify all resolution errors into a single
|
||||
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
|
||||
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
|
||||
|
||||
For ``bytes``, the length is the byte count directly. For ``str``,
|
||||
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
|
||||
mean the byte size can be up to 4x the character count.
|
||||
"""
|
||||
if isinstance(content, bytes):
|
||||
size = len(content)
|
||||
else:
|
||||
char_len = len(content)
|
||||
# Fast lower bound: UTF-8 byte count >= char count.
|
||||
# If char count already exceeds the limit, reject immediately
|
||||
# without allocating an encoded copy.
|
||||
if char_len > _MAX_BARE_REF_BYTES:
|
||||
size = char_len # real byte size is even larger
|
||||
# Fast upper bound: each char is at most 4 UTF-8 bytes.
|
||||
# If worst-case is still under the limit, skip encoding entirely.
|
||||
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
|
||||
return
|
||||
else:
|
||||
# Edge case: char count is under limit but multibyte chars
|
||||
# might push byte count over. Encode to get exact size.
|
||||
size = len(content.encode("utf-8"))
|
||||
if size > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large for structured parsing "
|
||||
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
|
||||
|
||||
async def _infer_format_from_workspace(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str | None:
|
||||
"""Look up workspace file metadata to infer the format.
|
||||
|
||||
Workspace URIs by ID (``workspace://abc123``) have no file extension.
|
||||
When the MIME fragment is also absent, we query the workspace file
|
||||
manager for the file's stored MIME type and original filename.
|
||||
"""
|
||||
if not user_id:
|
||||
return None
|
||||
try:
|
||||
ws = parse_workspace_uri(uri)
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
info = await (
|
||||
manager.get_file_info(ws.file_ref)
|
||||
if not ws.is_path
|
||||
else manager.get_file_info_by_path(ws.file_ref)
|
||||
)
|
||||
if info is None:
|
||||
return None
|
||||
# Try MIME type first, then filename extension.
|
||||
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
|
||||
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
|
||||
except (
|
||||
ValueError,
|
||||
FileNotFoundError,
|
||||
OSError,
|
||||
PermissionError,
|
||||
AttributeError,
|
||||
TypeError,
|
||||
):
|
||||
# Expected failures: bad URI, missing file, permission denied, or
|
||||
# workspace manager returning unexpected types. Propagate anything
|
||||
# else (e.g. programming errors) so they don't get silently swallowed.
|
||||
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
|
||||
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
|
||||
if prop_schema is None:
|
||||
return False
|
||||
return (
|
||||
prop_schema.get("type") == "string"
|
||||
and prop_schema.get("format") == MediaFileType.string_format
|
||||
)
|
||||
|
||||
|
||||
async def _expand_bare_ref(
|
||||
ref: FileRef,
|
||||
fmt: str | None,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
prop_schema: dict[str, Any] | None,
|
||||
) -> Any:
|
||||
"""Resolve and parse a bare ``@@agptfile:`` reference.
|
||||
|
||||
This is the structured-parsing path: the file is read, optionally parsed
|
||||
according to *fmt*, and adapted to the target *prop_schema*.
|
||||
|
||||
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
|
||||
|
||||
Note: MediaFileType fields (format: "file") are handled earlier in
|
||||
``_expand`` to avoid unnecessary format inference and file I/O.
|
||||
"""
|
||||
try:
|
||||
if fmt is not None and fmt in BINARY_FORMATS:
|
||||
# Binary formats need raw bytes, not UTF-8 text.
|
||||
# Line ranges are meaningless for binary formats (parquet/xlsx)
|
||||
# — ignore them and parse full bytes. Warn so the caller/model
|
||||
# knows the range was silently dropped.
|
||||
if ref.start_line is not None or ref.end_line is not None:
|
||||
logger.warning(
|
||||
"Line range [%s-%s] ignored for binary format %s (%s); "
|
||||
"binary formats are always parsed in full.",
|
||||
ref.start_line,
|
||||
ref.end_line,
|
||||
fmt,
|
||||
ref.uri,
|
||||
)
|
||||
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
|
||||
else:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# For known formats this rejects files >10 MB before parsing.
|
||||
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
|
||||
# but this check still guards the parsing path which has no char limit.
|
||||
# _check_content_size raises ValueError, which we unify here just like
|
||||
# resolution errors above.
|
||||
try:
|
||||
_check_content_size(content)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# When the schema declares this parameter as "string",
|
||||
# return raw file content — don't parse into a structured
|
||||
# type that would need json.dumps() serialisation.
|
||||
expect_string = (prop_schema or {}).get("type") == "string"
|
||||
if expect_string:
|
||||
if isinstance(content, bytes):
|
||||
raise FileRefExpansionError(
|
||||
f"Cannot use {fmt} file as text input: "
|
||||
f"binary formats (parquet, xlsx) must be passed "
|
||||
f"to a block that accepts structured data (list/object), "
|
||||
f"not a string-typed parameter."
|
||||
)
|
||||
return content
|
||||
|
||||
if fmt is not None:
|
||||
# Use strict mode for binary formats so we surface the
|
||||
# actual error (e.g. missing pyarrow/openpyxl, corrupt
|
||||
# file) instead of silently returning garbled bytes.
|
||||
strict = fmt in BINARY_FORMATS
|
||||
try:
|
||||
parsed = parse_file_content(content, fmt, strict=strict)
|
||||
except PARSE_EXCEPTIONS as exc:
|
||||
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
|
||||
# Normalize bytes fallback to str so tools never
|
||||
# receive raw bytes when parsing fails.
|
||||
if isinstance(parsed, bytes):
|
||||
parsed = _to_str(parsed)
|
||||
return _adapt_to_schema(parsed, prop_schema)
|
||||
|
||||
# Unknown format — return as plain string, but apply
|
||||
# the same per-ref character limit used by inline refs
|
||||
# to prevent injecting unexpectedly large content.
|
||||
text = _to_str(content)
|
||||
if len(text) > _MAX_EXPAND_CHARS:
|
||||
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
return text
|
||||
|
||||
|
||||
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
|
||||
"""Adapt a parsed file value to better fit the target schema type.
|
||||
|
||||
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
|
||||
that doesn't match the block's expected type, this function converts it to
|
||||
a more useful representation instead of relying on pydantic's generic
|
||||
coercion (which can produce awkward results like flattened dicts → lists).
|
||||
|
||||
Returns *parsed* unchanged when no adaptation is needed.
|
||||
"""
|
||||
if prop_schema is None:
|
||||
return parsed
|
||||
|
||||
target_type = prop_schema.get("type")
|
||||
|
||||
# Dict → array: delegate to helper.
|
||||
if isinstance(parsed, dict) and target_type == "array":
|
||||
return _adapt_dict_to_array(parsed, prop_schema)
|
||||
|
||||
# List → object: delegate to helper (raises for non-tabular lists).
|
||||
if isinstance(parsed, list) and target_type == "object":
|
||||
return _adapt_list_to_object(parsed)
|
||||
|
||||
# Tabular list → Any (no type): convert to list of dicts.
|
||||
# Blocks like FindInDictionaryBlock have `input: Any` which produces
|
||||
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
|
||||
# for key lookup, but [{col: val}, ...] works with FindInDict's
|
||||
# list-of-dicts branch (line 195-199 in data_manipulation.py).
|
||||
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
|
||||
return _tabular_to_list_of_dicts(parsed)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
|
||||
"""Adapt a parsed dict to an array-typed field.
|
||||
|
||||
Extracts list-valued entries when the target item type is ``array``,
|
||||
passes through unchanged when item type is ``string`` (lets pydantic error),
|
||||
or wraps in ``[parsed]`` as a fallback.
|
||||
"""
|
||||
items_type = (prop_schema.get("items") or {}).get("type")
|
||||
if items_type == "array":
|
||||
# Target is List[List[Any]] — extract list-typed values from the
|
||||
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
|
||||
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
|
||||
list_values = [v for v in parsed.values() if isinstance(v, list)]
|
||||
if list_values:
|
||||
return list_values
|
||||
if items_type == "string":
|
||||
# Target is List[str] — wrapping a dict would give [dict]
|
||||
# which can't coerce to strings. Return unchanged and let
|
||||
# pydantic surface a clear validation error.
|
||||
return parsed
|
||||
# Fallback: wrap in a single-element list so the block gets [dict]
|
||||
# instead of pydantic flattening keys/values into a flat list.
|
||||
return [parsed]
|
||||
|
||||
|
||||
def _adapt_list_to_object(parsed: list) -> Any:
|
||||
"""Adapt a parsed list to an object-typed field.
|
||||
|
||||
Converts tabular lists to column-dicts; raises for non-tabular lists.
|
||||
"""
|
||||
if _is_tabular(parsed):
|
||||
return _tabular_to_column_dict(parsed)
|
||||
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
|
||||
# be meaningfully coerced to an object. Raise explicitly so callers
|
||||
# get a clear error rather than pydantic silently wrapping the list.
|
||||
raise FileRefExpansionError(
|
||||
"Cannot adapt a non-tabular list to an object-typed field. "
|
||||
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
|
||||
)
|
||||
|
||||
|
||||
def _is_tabular(parsed: Any) -> bool:
|
||||
"""Check if parsed data is in tabular format: [[header], [row1], ...].
|
||||
|
||||
Uses isinstance checks because this is a structural type guard on
|
||||
opaque parser output (Any), not duck typing. A Protocol wouldn't
|
||||
help here — we need to verify exact list-of-lists shape.
|
||||
"""
|
||||
if not isinstance(parsed, list) or len(parsed) < 2:
|
||||
return False
|
||||
header = parsed[0]
|
||||
if not isinstance(header, list) or not header:
|
||||
return False
|
||||
if not all(isinstance(h, str) for h in header):
|
||||
return False
|
||||
return all(isinstance(row, list) for row in parsed[1:])
|
||||
|
||||
|
||||
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
|
||||
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values.
|
||||
Extra values beyond the header length are silently dropped.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return [
|
||||
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
|
||||
for row in parsed[1:]
|
||||
]
|
||||
|
||||
|
||||
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
|
||||
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values,
|
||||
ensuring all columns have equal length.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return {
|
||||
col: [row[i] if i < len(row) else None for row in parsed[1:]]
|
||||
for i, col in enumerate(header)
|
||||
}
|
||||
return {k: await _expand(v) for k, v in args.items()}
|
||||
|
||||
@@ -175,199 +175,6 @@ async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||
assert result["count"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — bare ref structured parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_json_returns_parsed_dict():
|
||||
"""Bare ref to a .json file returns parsed dict, not raw string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value", "count": 42}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == {"key": "value", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_csv_returns_parsed_table():
|
||||
"""Bare ref to a .csv file returns list[list[str]] table."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
csv_file = os.path.join(sdk_cwd, "data.csv")
|
||||
with open(csv_file, "w") as f:
|
||||
f.write("Name,Score\nAlice,90\nBob,85")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"input": f"@@agptfile:{csv_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["input"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
["Bob", "85"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_unknown_extension_returns_string():
|
||||
"""Bare ref to a file with unknown extension returns plain string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
txt_file = os.path.join(sdk_cwd, "readme.txt")
|
||||
with open(txt_file, "w") as f:
|
||||
f.write("plain text content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{txt_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "plain text content"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_invalid_json_falls_back_to_string():
|
||||
"""Bare ref to a .json file with invalid JSON falls back to string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "bad.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write("not valid json {{{")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "not valid json {{{"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedded_ref_always_returns_string_even_for_json():
|
||||
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value"}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"prefix @@agptfile:{json_file} suffix"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert isinstance(result["data"], str)
|
||||
assert result["data"].startswith("prefix ")
|
||||
assert result["data"].endswith(" suffix")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_yaml_returns_parsed_dict():
|
||||
"""Bare ref to a .yaml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
yaml_file = os.path.join(sdk_cwd, "config.yaml")
|
||||
with open(yaml_file, "w") as f:
|
||||
f.write("name: test\ncount: 42\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{yaml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_binary_with_line_range_ignores_range():
|
||||
"""Bare ref to a binary file (.parquet) with line range parses the full file.
|
||||
|
||||
Binary formats (parquet, xlsx) ignore line ranges — the full content is
|
||||
parsed and the range is silently dropped with a log warning.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pytest.skip("pandas not installed")
|
||||
try:
|
||||
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
except ImportError:
|
||||
pytest.skip("pyarrow not installed")
|
||||
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
parquet_file = os.path.join(sdk_cwd, "data.parquet")
|
||||
import io as _io
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = _io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
with open(parquet_file, "wb") as f:
|
||||
f.write(buf.getvalue())
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
# Line range [1-2] should be silently ignored for binary formats.
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{parquet_file}[1-2]"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
# Full file is returned despite the line range.
|
||||
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_toml_returns_parsed_dict():
|
||||
"""Bare ref to a .toml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
toml_file = os.path.join(sdk_cwd, "config.toml")
|
||||
with open(toml_file, "w") as f:
|
||||
f.write('name = "test"\ncount = 42\n')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{toml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — extended to accept workspace:// and local paths
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -412,7 +219,7 @@ async def test_read_file_handler_workspace_uri():
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
@@ -469,7 +276,7 @@ async def test_read_file_bytes_workspace_virtual_path():
|
||||
mock_manager.read_file.return_value = b"virtual path content"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,40 +20,9 @@ Use these URLs directly without asking the user:
|
||||
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
|
||||
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
|
||||
|
||||
For other services, search the MCP registry API:
|
||||
```http
|
||||
GET https://registry.modelcontextprotocol.io/v0/servers?q=<search_term>
|
||||
```
|
||||
Each result includes a `remotes` array with the exact server URL to use.
|
||||
|
||||
### Important: Check blocks first
|
||||
|
||||
Before using `run_mcp_tool`, always check if the platform already has blocks for the service
|
||||
using `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Google Docs,
|
||||
Google Calendar, Gmail, etc.) that work without MCP setup.
|
||||
|
||||
Only use `run_mcp_tool` when:
|
||||
- The service is in the known hosted MCP servers list above, OR
|
||||
- You searched `find_block` first and found no matching blocks
|
||||
|
||||
**Never guess or construct MCP server URLs.** Only use URLs from the known servers list above
|
||||
or from the `remotes[].url` field in MCP registry search results.
|
||||
For other services, search the MCP registry at https://registry.modelcontextprotocol.io/.
|
||||
|
||||
### Authentication
|
||||
|
||||
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
|
||||
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
|
||||
|
||||
### Communication style
|
||||
|
||||
Avoid technical jargon like "MCP server", "OAuth", or "credentials" when talking to the user.
|
||||
Use plain, friendly language instead:
|
||||
|
||||
| Instead of… | Say… |
|
||||
|---|---|
|
||||
| "Let me connect to Sentry's MCP server and discover what tools are available." | "I can connect to Sentry and help identify important issues." |
|
||||
| "Let me connect to Sentry's MCP server now." | "Next, I'll connect to Sentry." |
|
||||
| "The MCP server at mcp.sentry.dev requires authentication. Please connect your credentials to continue." | "To continue, sign in to Sentry and approve access." |
|
||||
| "Sentry's MCP server needs OAuth authentication. You should see a prompt to connect your Sentry account…" | "You should see a prompt to sign in to Sentry. Once connected, I can help surface critical issues right away." |
|
||||
|
||||
Use **"connect to [Service]"** or **"sign in to [Service]"** — never "MCP server", "OAuth", or "credentials".
|
||||
|
||||
@@ -1,651 +0,0 @@
|
||||
"""Tests for retry logic and transcript compaction helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _friendly_error_text, _is_prompt_too_long
|
||||
from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_run_compression,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
|
||||
def test_text_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"},
|
||||
]
|
||||
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||
|
||||
def test_tool_use_blocks(self):
|
||||
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||
|
||||
def test_mixed_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Let me read that."},
|
||||
{"type": "tool_use", "name": "Read", "input": {"path": "/foo"}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "Let me read that." in result
|
||||
assert "[tool_use: Read]" in result
|
||||
|
||||
def test_raw_strings(self):
|
||||
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "See this image:"},
|
||||
{"type": "image", "source": {"type": "base64", "data": "..."}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "See this image:" in result
|
||||
assert "[__image__]" in result
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_assistant_content([]) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_tool_result_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
|
||||
def test_tool_result_with_text(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": [{"type": "text", "text": "file contents here"}],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||
|
||||
def test_tool_result_with_string_content(self):
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "123", "content": "ok"}]
|
||||
assert _flatten_tool_result_content(blocks) == "ok"
|
||||
|
||||
def test_text_block(self):
|
||||
blocks = [{"type": "text", "text": "plain text"}]
|
||||
assert _flatten_tool_result_content(blocks) == "plain text"
|
||||
|
||||
def test_raw_string(self):
|
||||
assert _flatten_tool_result_content(["raw"]) == "raw"
|
||||
|
||||
def test_tool_result_with_none_content(self):
|
||||
"""tool_result with content=None should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": None}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_tool_result_with_empty_list_content(self):
|
||||
"""tool_result with content=[] should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": []}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_tool_result_content([]) == ""
|
||||
|
||||
def test_nested_dict_without_text(self):
|
||||
"""Dict blocks without text key use json.dumps fallback."""
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "x",
|
||||
"content": [{"type": "image", "source": "data:..."}],
|
||||
}
|
||||
]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "image" in result # json.dumps fallback
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [{"type": "image", "source": {"type": "base64", "data": "..."}}]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "[__image__]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_entry(entry_type: str, role: str, content: str | list, **kwargs) -> str:
|
||||
"""Build a JSONL line for testing."""
|
||||
uid = str(uuid4())
|
||||
msg: dict = {"role": role, "content": content}
|
||||
msg.update(kwargs)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": None,
|
||||
"message": msg,
|
||||
}
|
||||
return json.dumps(entry, separators=(",", ":"))
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
|
||||
def test_basic_roundtrip(self):
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == {"role": "user", "content": "Hello"}
|
||||
assert messages[1] == {"role": "assistant", "content": "Hi"}
|
||||
|
||||
def test_skips_strippable_types(self):
|
||||
"""Progress and metadata entries are excluded."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "progress",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"role": "assistant", "content": "..."},
|
||||
}
|
||||
),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _transcript_to_messages("") == []
|
||||
|
||||
def test_tool_result_content(self):
|
||||
"""User entries with tool_result content blocks are flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": "tool output",
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "tool output"
|
||||
|
||||
def test_malformed_json_lines_skipped(self):
|
||||
"""Malformed JSON lines in transcript are silently skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"this is not valid json",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_lines_skipped(self):
|
||||
"""Empty lines and whitespace-only lines are skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"",
|
||||
" ",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_unicode_content_preserved(self):
|
||||
"""Unicode characters survive transcript roundtrip."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello 你好 🌍"),
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[{"type": "text", "text": "Bonjour 日本語 émojis 🎉"}],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert messages[0]["content"] == "Hello 你好 🌍"
|
||||
assert messages[1]["content"] == "Bonjour 日本語 émojis 🎉"
|
||||
|
||||
def test_entry_without_role_skipped(self):
|
||||
"""Entries with missing role in message are skipped."""
|
||||
entry_no_role = json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"content": "no role here"},
|
||||
}
|
||||
)
|
||||
lines = [
|
||||
entry_no_role,
|
||||
_make_entry("user", "user", "Hello"),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "Hello"
|
||||
|
||||
def test_tool_use_and_result_pairs(self):
|
||||
"""Tool use + tool result pairs are properly flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[
|
||||
{"type": "text", "text": "Let me check."},
|
||||
{"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
|
||||
],
|
||||
),
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "abc",
|
||||
"content": [{"type": "text", "text": "file contents"}],
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert "Let me check." in messages[0]["content"]
|
||||
assert "[tool_use: read_file]" in messages[0]["content"]
|
||||
assert messages[1]["content"] == "file contents"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _messages_to_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
|
||||
def test_produces_valid_jsonl(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
assert len(lines) == 2
|
||||
for line in lines:
|
||||
parsed = json.loads(line)
|
||||
assert "type" in parsed
|
||||
assert "uuid" in parsed
|
||||
assert "message" in parsed
|
||||
|
||||
def test_assistant_has_proper_structure(self):
|
||||
messages = [{"role": "assistant", "content": "Hello"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "assistant"
|
||||
msg = entry["message"]
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["type"] == "message"
|
||||
assert msg["stop_reason"] == "end_turn"
|
||||
assert isinstance(msg["content"], list)
|
||||
assert msg["content"][0]["type"] == "text"
|
||||
|
||||
def test_user_has_plain_content(self):
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "user"
|
||||
assert entry["message"]["content"] == "Hi"
|
||||
|
||||
def test_parent_uuid_chain(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "A"},
|
||||
{"role": "assistant", "content": "B"},
|
||||
{"role": "user", "content": "C"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
entries = [json.loads(line) for line in lines]
|
||||
assert entries[0]["parentUuid"] == ""
|
||||
assert entries[1]["parentUuid"] == entries[0]["uuid"]
|
||||
assert entries[2]["parentUuid"] == entries[1]["uuid"]
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _messages_to_transcript([]) == ""
|
||||
|
||||
def test_output_is_valid_transcript(self):
|
||||
"""Output should pass validate_transcript if it has assistant entries."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result)
|
||||
|
||||
def test_roundtrip_to_messages(self):
|
||||
"""Messages → transcript → messages preserves structure."""
|
||||
original = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
transcript = _messages_to_transcript(original)
|
||||
restored = _transcript_to_messages(transcript)
|
||||
assert len(restored) == len(original)
|
||||
for orig, rest in zip(original, restored):
|
||||
assert orig["role"] == rest["role"]
|
||||
assert orig["content"] == rest["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compact_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_few_messages_returns_none(self, mock_chat_config):
|
||||
"""compact_transcript returns None when transcript has < 2 messages."""
|
||||
transcript = _build_transcript([("user", "Hello")])
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_compacted(self, mock_chat_config):
|
||||
"""When compress_context says no compaction needed, returns None.
|
||||
The compressor couldn't reduce it, so retrying with the same
|
||||
content would fail identically."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi there"),
|
||||
]
|
||||
)
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": False,
|
||||
"messages": [],
|
||||
"original_token_count": 100,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_compacted_transcript(self, mock_chat_config):
|
||||
"""When compaction succeeds, returns a valid compacted transcript."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
("user", "More"),
|
||||
("assistant", "Details"),
|
||||
]
|
||||
)
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 500,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[1]["content"] == "Summarized response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_compression_failure(self, mock_chat_config):
|
||||
"""When _run_compression raises, returns None."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
]
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
"""Unit tests for _is_prompt_too_long pattern matching."""
|
||||
|
||||
def test_prompt_is_too_long(self):
|
||||
err = RuntimeError("prompt is too long for model context")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_request_too_large(self):
|
||||
err = Exception("request too large: 250000 tokens")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_maximum_context_length(self):
|
||||
err = ValueError("maximum context length exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_context_length_exceeded(self):
|
||||
err = Exception("context_length_exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_tokens_exceed(self):
|
||||
err = Exception("input tokens exceed the max_tokens limit")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_is_too_long(self):
|
||||
err = Exception("input is too long for the model")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_content_length_exceeds(self):
|
||||
err = Exception("content length exceeds maximum")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_unrelated_error_returns_false(self):
|
||||
err = RuntimeError("network timeout")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_auth_error_returns_false(self):
|
||||
err = Exception("authentication failed: invalid API key")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_chained_exception_detected(self):
|
||||
"""Prompt-too-long error wrapped in another exception is detected."""
|
||||
inner = RuntimeError("prompt is too long")
|
||||
outer = Exception("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
err = Exception("PROMPT IS TOO LONG")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_old_max_tokens_exceeded_not_matched(self):
|
||||
"""The old broad 'max_tokens_exceeded' pattern was removed.
|
||||
Only 'input tokens exceed' should match now."""
|
||||
err = Exception("max_tokens_exceeded")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_compression timeout fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunCompressionTimeout:
|
||||
"""Verify _run_compression falls back to truncation when LLM times out."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_falls_back_to_truncation(self):
|
||||
"""When compress_context with LLM client times out,
|
||||
_run_compression falls back to truncation (client=None)."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
truncation_result = CompressResult(
|
||||
messages=messages,
|
||||
was_compacted=False,
|
||||
original_token_count=50,
|
||||
token_count=50,
|
||||
messages_summarized=0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
call_args: list[dict] = []
|
||||
|
||||
async def _mock_compress(**kwargs):
|
||||
call_args.append(kwargs)
|
||||
if kwargs.get("client") is not None:
|
||||
# Simulate timeout by raising asyncio.TimeoutError
|
||||
raise asyncio.TimeoutError("LLM compaction timed out")
|
||||
return truncation_result
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value="fake-client",
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_mock_compress,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(messages, "test-model", "[test]")
|
||||
|
||||
assert result == truncation_result
|
||||
# Should have been called twice: once with client, once without
|
||||
assert len(call_args) == 2
|
||||
assert call_args[0]["client"] is not None # LLM attempt
|
||||
assert call_args[1]["client"] is None # truncation fallback
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation_directly(self):
|
||||
"""When no OpenAI client is configured, goes straight to truncation."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
truncation_result = CompressResult(
|
||||
messages=messages,
|
||||
was_compacted=False,
|
||||
original_token_count=50,
|
||||
token_count=50,
|
||||
messages_summarized=0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(messages, "test-model", "[test]")
|
||||
|
||||
assert result == truncation_result
|
||||
mock_compress.assert_called_once()
|
||||
# When no client, compress_context is called with client=None
|
||||
assert mock_compress.call_args.kwargs.get("client") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _friendly_error_text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFriendlyErrorText:
|
||||
"""Verify user-friendly error message mapping."""
|
||||
|
||||
def test_authentication_error(self):
|
||||
result = _friendly_error_text("authentication failed: invalid API key")
|
||||
assert "Authentication" in result
|
||||
assert "API key" in result
|
||||
|
||||
def test_rate_limit_error(self):
|
||||
result = _friendly_error_text("rate limit exceeded")
|
||||
assert "Rate limit" in result
|
||||
|
||||
def test_overloaded_error(self):
|
||||
result = _friendly_error_text("API is overloaded")
|
||||
assert "overloaded" in result
|
||||
|
||||
def test_timeout_error(self):
|
||||
result = _friendly_error_text("Request timeout after 30s")
|
||||
assert "timed out" in result
|
||||
|
||||
def test_connection_error(self):
|
||||
result = _friendly_error_text("Connection refused")
|
||||
assert "Connection" in result or "connection" in result
|
||||
|
||||
def test_unknown_error_passthrough(self):
|
||||
result = _friendly_error_text("some unknown error XYZ")
|
||||
assert "SDK stream error:" in result
|
||||
assert "XYZ" in result
|
||||
|
||||
def test_unauthorized_error(self):
|
||||
result = _friendly_error_text("401 Unauthorized")
|
||||
assert "Authentication" in result
|
||||
@@ -20,7 +20,6 @@ from claude_agent_sdk import (
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.copilot.constants import FRIENDLY_TRANSIENT_MSG, is_transient_api_error
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -215,12 +214,10 @@ class SDKResponseAdapter:
|
||||
if sdk_message.subtype == "success":
|
||||
responses.append(StreamFinish())
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
raw_error = str(sdk_message.result or "Unknown error")
|
||||
if is_transient_api_error(raw_error):
|
||||
error_text, code = FRIENDLY_TRANSIENT_MSG, "transient_api_error"
|
||||
else:
|
||||
error_text, code = raw_error, "sdk_error"
|
||||
responses.append(StreamError(errorText=error_text, code=code))
|
||||
error_msg = sdk_message.result or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -42,7 +42,7 @@ def _validate_workspace_path(
|
||||
Delegates to :func:`is_allowed_local_path` which permits:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The current session's tool-results directory
|
||||
(``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/``)
|
||||
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
@@ -127,7 +127,7 @@ def create_security_hooks(
|
||||
user_id: str | None,
|
||||
sdk_cwd: str | None = None,
|
||||
max_subtasks: int = 3,
|
||||
on_compact: Callable[[str], None] | None = None,
|
||||
on_compact: Callable[[], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
@@ -142,7 +142,6 @@ def create_security_hooks(
|
||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
|
||||
on_compact: Callback invoked when SDK starts compacting context.
|
||||
Receives the transcript_path from the hook input.
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
@@ -302,25 +301,11 @@ def create_security_hooks(
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
# Sanitize untrusted input: strip control chars for logging AND
|
||||
# for the value passed downstream. read_compacted_entries()
|
||||
# validates against _projects_base() as defence-in-depth, but
|
||||
# sanitizing here prevents log injection and rejects obviously
|
||||
# malformed paths early.
|
||||
transcript_path = (
|
||||
str(input_data.get("transcript_path", ""))
|
||||
.replace("\n", "")
|
||||
.replace("\r", "")
|
||||
)
|
||||
logger.info(
|
||||
"[SDK] Context compaction triggered: %s, user=%s, "
|
||||
"transcript_path=%s",
|
||||
trigger,
|
||||
user_id,
|
||||
transcript_path,
|
||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||
)
|
||||
if on_compact is not None:
|
||||
on_compact(transcript_path)
|
||||
on_compact()
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
hooks: dict[str, Any] = {
|
||||
|
||||
@@ -122,7 +122,7 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-results/12345.txt"
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,283 +0,0 @@
|
||||
"""Unit tests for extracted service helpers.
|
||||
|
||||
Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
|
||||
and the ``ReducedContext`` named tuple.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import (
|
||||
ReducedContext,
|
||||
_is_prompt_too_long,
|
||||
_iter_sdk_messages,
|
||||
_reduce_context,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
def test_direct_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("prompt is too long")) is True
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("PROMPT IS TOO LONG")) is True
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("network timeout")) is False
|
||||
|
||||
def test_request_too_large(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("request too large for model")) is True
|
||||
|
||||
def test_context_length_exceeded(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("context_length_exceeded")) is True
|
||||
|
||||
def test_max_tokens_exceeded_not_matched(self) -> None:
|
||||
"""'max_tokens_exceeded' is intentionally excluded (too broad)."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens_exceeded")) is False
|
||||
|
||||
def test_max_tokens_config_error_no_match(self) -> None:
|
||||
"""'max_tokens must be at least 1' should NOT match."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens must be at least 1")) is False
|
||||
|
||||
def test_chained_cause(self) -> None:
|
||||
inner = Exception("prompt is too long")
|
||||
outer = RuntimeError("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_chained_context(self) -> None:
|
||||
inner = Exception("request too large")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__context__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_deep_chain(self) -> None:
|
||||
bottom = Exception("maximum context length")
|
||||
middle = RuntimeError("middle")
|
||||
middle.__cause__ = bottom
|
||||
top = ValueError("top")
|
||||
top.__cause__ = middle
|
||||
assert _is_prompt_too_long(top) is True
|
||||
|
||||
def test_chain_no_match(self) -> None:
|
||||
inner = Exception("rate limit exceeded")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is False
|
||||
|
||||
def test_cycle_detection(self) -> None:
|
||||
"""Exception chain with a cycle should not infinite-loop."""
|
||||
a = Exception("error a")
|
||||
b = Exception("error b")
|
||||
a.__cause__ = b
|
||||
b.__cause__ = a # cycle
|
||||
assert _is_prompt_too_long(a) is False
|
||||
|
||||
def test_all_patterns(self) -> None:
|
||||
patterns = [
|
||||
"prompt is too long",
|
||||
"request too large",
|
||||
"maximum context length",
|
||||
"context_length_exceeded",
|
||||
"input tokens exceed",
|
||||
"input is too long",
|
||||
"content length exceeds",
|
||||
]
|
||||
for pattern in patterns:
|
||||
assert _is_prompt_too_long(Exception(pattern)) is True, pattern
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reduce_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReduceContext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_retry_compaction_success(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value="/tmp/resume.jsonl",
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert isinstance(ctx, ReducedContext)
|
||||
assert ctx.use_resume is True
|
||||
assert ctx.resume_file == "/tmp/resume.jsonl"
|
||||
assert ctx.transcript_lost is False
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_fails_drops_transcript(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.resume_file is None
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_tried_compaction_skips(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
ctx = await _reduce_context(transcript, True, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_drops(self) -> None:
|
||||
ctx = await _reduce_context("", False, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_returns_same_content_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript, # same content
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_tempfile_fails_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _iter_sdk_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIterSdkMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_yields_messages(self) -> None:
|
||||
messages = ["msg1", "msg2", "msg3"]
|
||||
client = AsyncMock()
|
||||
|
||||
async def _fake_receive() -> AsyncGenerator[str]:
|
||||
for m in messages:
|
||||
yield m
|
||||
|
||||
client.receive_response = _fake_receive
|
||||
result = [msg async for msg in _iter_sdk_messages(client)]
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_on_timeout(self) -> None:
|
||||
"""Yields None when asyncio.wait times out."""
|
||||
client = AsyncMock()
|
||||
received: list = []
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
await asyncio.sleep(100) # never completes
|
||||
yield "never" # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
with patch("backend.copilot.sdk.service._HEARTBEAT_INTERVAL", 0.01):
|
||||
count = 0
|
||||
async for msg in _iter_sdk_messages(client):
|
||||
received.append(msg)
|
||||
count += 1
|
||||
if count >= 3:
|
||||
break
|
||||
|
||||
assert all(m is None for m in received)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_propagates(self) -> None:
|
||||
client = AsyncMock()
|
||||
|
||||
async def _error_receive() -> AsyncGenerator[str]:
|
||||
raise RuntimeError("SDK crash")
|
||||
yield # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _error_receive
|
||||
|
||||
with pytest.raises(RuntimeError, match="SDK crash"):
|
||||
async for _ in _iter_sdk_messages(client):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cleanup_on_break(self) -> None:
|
||||
"""Pending task is cancelled when generator is closed."""
|
||||
client = AsyncMock()
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
yield "first"
|
||||
await asyncio.sleep(100)
|
||||
yield "second"
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
gen = _iter_sdk_messages(client)
|
||||
first = await gen.__anext__()
|
||||
assert first == "first"
|
||||
await gen.aclose() # should cancel pending task cleanly
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user