mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
12 Commits
fix/copilo
...
feat/analy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9afd9fa01 | ||
|
|
ddb4f6e9de | ||
|
|
f585d97928 | ||
|
|
7d39234fdd | ||
|
|
6e9d4c4333 | ||
|
|
8aad333a45 | ||
|
|
856f0d980d | ||
|
|
3c3aadd361 | ||
|
|
e87a693fdd | ||
|
|
fe265c10d4 | ||
|
|
5d00a94693 | ||
|
|
6e1605994d |
17
.claude/skills/backend-check/SKILL.md
Normal file
17
.claude/skills/backend-check/SKILL.md
Normal file
@@ -0,0 +1,17 @@
|
||||
---
|
||||
name: backend-check
|
||||
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Backend Check
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
|
||||
2. **Fix** any remaining errors manually, re-run until clean
|
||||
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
|
||||
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`
|
||||
35
.claude/skills/code-style/SKILL.md
Normal file
35
.claude/skills/code-style/SKILL.md
Normal file
@@ -0,0 +1,35 @@
|
||||
---
|
||||
name: code-style
|
||||
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
|
||||
user-invocable: false
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Code Style
|
||||
|
||||
## Imports
|
||||
|
||||
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
|
||||
|
||||
## Typing
|
||||
|
||||
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
|
||||
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
|
||||
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
|
||||
|
||||
## Code Structure
|
||||
|
||||
- **List comprehensions** over manual loop-and-append.
|
||||
- **Early return** — guard clauses first, avoid deep nesting.
|
||||
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
|
||||
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
Before finishing, always ask:
|
||||
- Can any function be split into smaller pieces?
|
||||
- Is there unnecessary nesting that an early return would eliminate?
|
||||
- Can any loop be a comprehension?
|
||||
- Is there a simpler way to express this logic?
|
||||
16
.claude/skills/frontend-check/SKILL.md
Normal file
16
.claude/skills/frontend-check/SKILL.md
Normal file
@@ -0,0 +1,16 @@
|
||||
---
|
||||
name: frontend-check
|
||||
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Frontend Check
|
||||
|
||||
## Steps (in order)
|
||||
|
||||
1. **Format**: `pnpm format` — NEVER run individual formatters
|
||||
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
|
||||
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user
|
||||
29
.claude/skills/new-block/SKILL.md
Normal file
29
.claude/skills/new-block/SKILL.md
Normal file
@@ -0,0 +1,29 @@
|
||||
---
|
||||
name: new-block
|
||||
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# New Block Creation
|
||||
|
||||
Read `docs/platform/block-sdk-guide.md` first for the full guide.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
|
||||
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
|
||||
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
|
||||
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
|
||||
- `async def run` that `yield`s output fields
|
||||
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
|
||||
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
|
||||
5. **Format**: `poetry run format`
|
||||
|
||||
## Rules
|
||||
|
||||
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
|
||||
- Use top-level imports, avoid duck typing
|
||||
- Always use `for_block_output` for block outputs
|
||||
28
.claude/skills/openapi-regen/SKILL.md
Normal file
28
.claude/skills/openapi-regen/SKILL.md
Normal file
@@ -0,0 +1,28 @@
|
||||
---
|
||||
name: openapi-regen
|
||||
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# OpenAPI Spec Regeneration
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID
|
||||
pnpm types && pnpm lint && pnpm format
|
||||
```
|
||||
|
||||
## Rules
|
||||
|
||||
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
|
||||
- Don't manually edit files in `src/app/api/__generated__/`
|
||||
- Generated hooks follow: `use{Method}{Version}{OperationName}`
|
||||
@@ -1,79 +0,0 @@
|
||||
---
|
||||
name: pr-address
|
||||
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
|
||||
user-invocable: true
|
||||
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Address
|
||||
|
||||
## Find the PR
|
||||
|
||||
```bash
|
||||
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Fetch comments (all sources)
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
|
||||
```
|
||||
|
||||
**Bots to watch for:**
|
||||
- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
|
||||
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
|
||||
- `coderabbitai[bot]` — automated review. Address actionable items.
|
||||
|
||||
## For each unaddressed comment
|
||||
|
||||
Address comments **one at a time**: fix → commit → push → inline reply → next.
|
||||
|
||||
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
|
||||
2. Commit and push the fix
|
||||
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
|
||||
|
||||
| Comment type | How to reply |
|
||||
|---|---|
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |
|
||||
|
||||
## Format and commit
|
||||
|
||||
After fixing, format the changed code:
|
||||
|
||||
- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
|
||||
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`
|
||||
|
||||
If API routes changed, regenerate the frontend client:
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
trap "kill $REST_PID 2>/dev/null" EXIT
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID 2>/dev/null; trap - EXIT
|
||||
```
|
||||
Never manually edit files in `src/app/api/__generated__/`.
|
||||
|
||||
Then commit and **push immediately** — never batch commits without pushing.
|
||||
|
||||
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
## The loop
|
||||
|
||||
```text
|
||||
address comments → format → commit → push
|
||||
→ re-check comments → fix new ones → push
|
||||
→ wait for CI → re-check comments after CI settles
|
||||
→ repeat until: all comments addressed AND CI green AND no new comments arriving
|
||||
```
|
||||
|
||||
While CI runs, stay productive: run local tests, address remaining comments.
|
||||
|
||||
**The loop ends when:** CI fully green + all comments addressed + no new comments since CI settled.
|
||||
31
.claude/skills/pr-create/SKILL.md
Normal file
31
.claude/skills/pr-create/SKILL.md
Normal file
@@ -0,0 +1,31 @@
|
||||
---
|
||||
name: pr-create
|
||||
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Create Pull Request
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
|
||||
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
|
||||
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
|
||||
5. **Fill out PR template** as the body — be thorough in the Changes section
|
||||
6. **Format first** (if relevant changes exist):
|
||||
- Backend: `cd autogpt_platform/backend && poetry run format`
|
||||
- Frontend: `cd autogpt_platform/frontend && pnpm format`
|
||||
- Fix any lint errors, then commit formatting changes before pushing
|
||||
7. **Push**: `git push -u origin HEAD`
|
||||
8. **Create PR**: `gh pr create --base dev`
|
||||
9. **Output** the PR URL
|
||||
|
||||
## Rules
|
||||
|
||||
- Always target `dev` branch
|
||||
- Do NOT run tests — CI will handle that
|
||||
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
@@ -1,74 +1,51 @@
|
||||
---
|
||||
name: pr-review
|
||||
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
|
||||
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
|
||||
user-invocable: true
|
||||
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Review
|
||||
# PR Review Comment Workflow
|
||||
|
||||
## Find the PR
|
||||
## Steps
|
||||
|
||||
```bash
|
||||
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||
gh pr view {N}
|
||||
```
|
||||
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
|
||||
2. **Fetch comments** (all three sources):
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
|
||||
3. **Skip** comments already reacted to by PR author
|
||||
4. **For each unreacted comment**:
|
||||
- Read referenced code, make the fix (or reply if you disagree/need info)
|
||||
- **Inline review comments** (`pulls/{N}/comments`):
|
||||
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
|
||||
- **PR conversation comments** (`issues/{N}/comments`):
|
||||
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- No threaded replies — post a new issue comment if needed
|
||||
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
|
||||
5. **Include autogpt-reviewer bot fixes** too
|
||||
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
|
||||
7. **Commit & push**
|
||||
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
|
||||
9. **Stay productive while CI runs** — don't idle. In priority order:
|
||||
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
|
||||
- Address any remaining comments
|
||||
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
|
||||
10. **If CI fails** — fix, go back to step 6
|
||||
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
|
||||
12. **Done** only when: all comments reacted AND CI is green.
|
||||
|
||||
## Read the diff
|
||||
## CRITICAL: Do Not Stop
|
||||
|
||||
```bash
|
||||
gh pr diff {N}
|
||||
```
|
||||
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
|
||||
|
||||
## Fetch existing review comments
|
||||
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
|
||||
|
||||
Before posting anything, fetch existing inline comments to avoid duplicates:
|
||||
## Rules
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||
```
|
||||
|
||||
## What to check
|
||||
|
||||
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||
|
||||
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||
|
||||
**Code quality:** apply rules from backend/frontend CLAUDE.md files.
|
||||
|
||||
**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.
|
||||
|
||||
**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where symbol is **used** not defined, `AsyncMock` for async.
|
||||
|
||||
## Output format
|
||||
|
||||
Every comment **must** be prefixed with `🤖` and a criticality badge:
|
||||
|
||||
| Tier | Badge | Meaning |
|
||||
|---|---|---|
|
||||
| Blocker | `🔴 **Blocker**` | Must fix before merge |
|
||||
| Should Fix | `🟠 **Should Fix**` | Important improvement |
|
||||
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
|
||||
| Nit | `🔵 **Nit**` | Style / wording |
|
||||
|
||||
Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
|
||||
|
||||
## Post inline comments
|
||||
|
||||
For each finding, post an inline comment on the PR (do not just write a local report):
|
||||
|
||||
```bash
|
||||
# Get the latest commit SHA for the PR
|
||||
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')
|
||||
|
||||
# Post an inline comment on a specific file/line
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
|
||||
-f body="🤖 🔴 **Blocker**: <description>" \
|
||||
-f commit_id="$COMMIT_SHA" \
|
||||
-f path="<file path>" \
|
||||
-F line=<line number>
|
||||
```
|
||||
- One todo per comment
|
||||
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
|
||||
- React to every comment: +1 addressed, -1 disagreed (with explanation)
|
||||
|
||||
45
.claude/skills/worktree-setup/SKILL.md
Normal file
45
.claude/skills/worktree-setup/SKILL.md
Normal file
@@ -0,0 +1,45 @@
|
||||
---
|
||||
name: worktree-setup
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Preferred: Use Branchlet
|
||||
|
||||
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
|
||||
|
||||
```bash
|
||||
npm install -g branchlet # install once
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
branchlet list --json # list all worktrees
|
||||
```
|
||||
|
||||
## Manual Fallback
|
||||
|
||||
If branchlet isn't available:
|
||||
|
||||
1. `git worktree add ../<RepoName><N> <branch-name>`
|
||||
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
|
||||
3. Install deps:
|
||||
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
|
||||
- `cd autogpt_platform/frontend && pnpm install`
|
||||
|
||||
## Running the App
|
||||
|
||||
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
|
||||
|
||||
```bash
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd <worktree>/autogpt_platform/backend && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot Testing Gotcha
|
||||
|
||||
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
@@ -1,85 +0,0 @@
|
||||
---
|
||||
name: worktree
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses next available AutoGPT<N>."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "3.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Create the worktree
|
||||
|
||||
Derive paths from the git toplevel. If a name is provided as argument, use it. Otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
PARENT=$(dirname "$ROOT")
|
||||
|
||||
# From an existing branch
|
||||
git worktree add "$PARENT/<NAME>" <branch-name>
|
||||
|
||||
# From a new branch off dev
|
||||
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
|
||||
```
|
||||
|
||||
## Copy environment files
|
||||
|
||||
Copy `.env` from the root worktree. Falls back to `.env.default` if `.env` doesn't exist.
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
TARGET="$(dirname "$ROOT")/<NAME>"
|
||||
|
||||
for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
|
||||
if [ -f "$ROOT/$envpath/.env" ]; then
|
||||
cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
|
||||
elif [ -f "$ROOT/$envpath/.env.default" ]; then
|
||||
cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Install dependencies
|
||||
|
||||
```bash
|
||||
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
|
||||
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
|
||||
cd "$TARGET/autogpt_platform/frontend" && pnpm install
|
||||
```
|
||||
|
||||
Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).
|
||||
|
||||
## Running the app (optional)
|
||||
|
||||
Backend uses ports: 8001, 8002, 8003, 8005, 8006, 8007, 8008. Free them first if needed:
|
||||
|
||||
```bash
|
||||
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd "$TARGET/autogpt_platform/backend" && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot testing
|
||||
|
||||
SDK mode spawns a Claude subprocess — won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
|
||||
## Cleanup
|
||||
|
||||
```bash
|
||||
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
|
||||
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
```
|
||||
|
||||
## Alternative: Branchlet (optional)
|
||||
|
||||
If [branchlet](https://www.npmjs.com/package/branchlet) is installed:
|
||||
|
||||
```bash
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
```
|
||||
@@ -60,12 +60,9 @@ AutoGPT Platform is a monorepo containing:
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
|
||||
@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
## ===== SIGNUP / INVITE GATE ===== ##
|
||||
# Set to true to require an invite before users can sign up
|
||||
ENABLE_INVITE_GATE=false
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
|
||||
@@ -58,31 +58,10 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
## Database Schema
|
||||
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class CreateInvitedUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class InvitedUserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
|
||||
return cls.model_validate(record.model_dump())
|
||||
|
||||
|
||||
class InvitedUsersResponse(BaseModel):
|
||||
invited_users: list[InvitedUserResponse]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class BulkInvitedUserRowResponse(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserResponse] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResponse(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResponse]
|
||||
|
||||
@classmethod
|
||||
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||
return cls(
|
||||
created_count=result.created_count,
|
||||
skipped_count=result.skipped_count,
|
||||
error_count=result.error_count,
|
||||
results=[
|
||||
BulkInvitedUserRowResponse(
|
||||
row_number=row.row_number,
|
||||
email=row.email,
|
||||
name=row.name,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
invited_user=(
|
||||
InvitedUserResponse.from_record(row.invited_user)
|
||||
if row.invited_user is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import logging
|
||||
import math
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||
|
||||
from backend.data.invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
list_invited_users,
|
||||
retry_invited_user_tally,
|
||||
revoke_invited_user,
|
||||
)
|
||||
from backend.data.tally import mask_email
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["users", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/invited-users",
|
||||
response_model=InvitedUsersResponse,
|
||||
summary="List Invited Users",
|
||||
)
|
||||
async def get_invited_users(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
) -> InvitedUsersResponse:
|
||||
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||
invited_users, total = await list_invited_users(page=page, page_size=page_size)
|
||||
return InvitedUsersResponse(
|
||||
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=max(1, math.ceil(total / page_size)),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Create Invited User",
|
||||
)
|
||||
async def create_invited_user_route(
|
||||
request: CreateInvitedUserRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s creating invited user for %s",
|
||||
admin_user_id,
|
||||
mask_email(request.email),
|
||||
)
|
||||
invited_user = await create_invited_user(request.email, request.name)
|
||||
logger.info(
|
||||
"Admin user %s created invited user %s",
|
||||
admin_user_id,
|
||||
invited_user.id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/bulk",
|
||||
response_model=BulkInvitedUsersResponse,
|
||||
summary="Bulk Create Invited Users",
|
||||
operation_id="postV2BulkCreateInvitedUsers",
|
||||
)
|
||||
async def bulk_create_invited_users_route(
|
||||
file: UploadFile = File(...),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> BulkInvitedUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s bulk invited users from %s",
|
||||
admin_user_id,
|
||||
file.filename or "<unnamed>",
|
||||
)
|
||||
content = await file.read()
|
||||
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||
return BulkInvitedUsersResponse.from_result(result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/revoke",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Revoke Invited User",
|
||||
)
|
||||
async def revoke_invited_user_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
|
||||
)
|
||||
invited_user = await revoke_invited_user(invited_user_id)
|
||||
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/retry-tally",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Retry Invited User Tally",
|
||||
)
|
||||
async def retry_invited_user_tally_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s retrying Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||
logger.info(
|
||||
"Admin user %s retried Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
@@ -1,168 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(user_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _sample_invited_user() -> InvitedUserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return InvitedUserRecord(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
auth_user_id=None,
|
||||
name="Invited User",
|
||||
tally_understanding=None,
|
||||
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tally_computed_at=None,
|
||||
tally_error=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=1,
|
||||
skipped_count=1,
|
||||
error_count=0,
|
||||
results=[
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=1,
|
||||
email="invited@example.com",
|
||||
name=None,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=_sample_invited_user(),
|
||||
),
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=2,
|
||||
email="duplicate@example.com",
|
||||
name=None,
|
||||
status="SKIPPED",
|
||||
message="An invited user with this email already exists",
|
||||
invited_user=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||
AsyncMock(return_value=([_sample_invited_user()], 1)),
|
||||
)
|
||||
|
||||
response = client.get("/admin/invited-users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["invited_users"]) == 1
|
||||
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||
assert data["invited_users"][0]["status"] == "INVITED"
|
||||
assert data["pagination"]["total_items"] == 1
|
||||
assert data["pagination"]["current_page"] == 1
|
||||
assert data["pagination"]["page_size"] == 50
|
||||
|
||||
|
||||
def test_create_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||
AsyncMock(return_value=_sample_invited_user()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users",
|
||||
json={"email": "invited@example.com", "name": "Invited User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "invited@example.com"
|
||||
assert data["name"] == "Invited User"
|
||||
|
||||
|
||||
def test_bulk_create_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users/bulk",
|
||||
files={
|
||||
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert data["skipped_count"] == 1
|
||||
assert data["results"][0]["status"] == "CREATED"
|
||||
assert data["results"][1]["status"] == "SKIPPED"
|
||||
|
||||
|
||||
def test_revoke_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _sample_invited_user().model_copy(
|
||||
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||
AsyncMock(return_value=revoked),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "REVOKED"
|
||||
|
||||
|
||||
def test_retry_invited_user_tally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
retried = _sample_invited_user().model_copy(
|
||||
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||
AsyncMock(return_value=retried),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
@@ -53,8 +53,6 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -129,7 +127,6 @@ class SessionSummaryResponse(BaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
is_processing: bool
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
@@ -188,28 +185,6 @@ async def list_sessions(
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
# Batch-check Redis for active stream status on each session
|
||||
processing_set: set[str] = set()
|
||||
if sessions:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
for session in sessions:
|
||||
pipe.hget(
|
||||
f"{config.session_meta_prefix}{session.session_id}",
|
||||
"status",
|
||||
)
|
||||
statuses = await pipe.execute()
|
||||
processing_set = {
|
||||
session.session_id
|
||||
for session, st in zip(sessions, statuses)
|
||||
if st == "running"
|
||||
}
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch processing status from Redis; " "defaulting to empty"
|
||||
)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
@@ -217,7 +192,6 @@ async def list_sessions(
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
is_processing=session.session_id in processing_set,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
@@ -854,36 +828,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts."""
|
||||
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts for the authenticated user.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns an empty list if no custom prompts
|
||||
are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None:
|
||||
return SuggestedPromptsResponse(prompts=[])
|
||||
|
||||
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
|
||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -249,62 +249,3 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding and prompts gets them back."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but no prompts gets empty list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = []
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
@@ -165,6 +165,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -205,9 +206,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
default_factory=list,
|
||||
description="List of recent executions with status, score, and summary",
|
||||
)
|
||||
can_access_graph: bool = pydantic.Field(
|
||||
description="Indicates whether the same user owns the corresponding graph"
|
||||
)
|
||||
can_access_graph: bool
|
||||
is_latest_version: bool
|
||||
is_favorite: bool
|
||||
folder_id: str | None = None
|
||||
@@ -325,6 +324,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
|
||||
@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -66,6 +67,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -129,6 +131,7 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -181,6 +184,7 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
|
||||
@@ -55,7 +55,6 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -71,6 +70,7 @@ from backend.data.onboarding import (
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
@@ -136,10 +136,12 @@ _tally_background_tasks: set[asyncio.Task] = set()
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_activate_user(user_data)
|
||||
user = await get_or_create_user(user_data)
|
||||
|
||||
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||
# not produce a stored result before first activation.
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
@@ -163,8 +165,7 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
email: str = Body(...),
|
||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
||||
) -> dict[str, str]:
|
||||
await update_user_email(user_id, email)
|
||||
|
||||
@@ -178,16 +179,10 @@ async def update_user_email_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Get user timezone setting."""
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail="User not found. Please complete activation via /auth/user first.",
|
||||
)
|
||||
user = await get_or_create_user(user_data)
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@@ -198,8 +193,7 @@ async def get_user_timezone_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: UpdateTimezoneRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
) -> TimezoneResponse:
|
||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||
user = await update_user_timezone(user_id, str(request.timezone))
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_or_activate_user",
|
||||
"backend.api.features.v1.get_or_create_user",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -94,8 +94,3 @@ class NotificationPayload(pydantic.BaseModel):
|
||||
|
||||
class OnboardingNotificationPayload(NotificationPayload):
|
||||
step: OnboardingStep | None
|
||||
|
||||
|
||||
class CopilotCompletionPayload(NotificationPayload):
|
||||
session_id: str
|
||||
status: Literal["completed", "failed"]
|
||||
|
||||
@@ -19,7 +19,6 @@ from prisma.errors import PrismaError
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.admin.user_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
import backend.api.features.chat.routes as chat_routes
|
||||
@@ -312,11 +311,6 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.user_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/users",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -96,7 +96,6 @@ class SendEmailBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Email sent successfully")],
|
||||
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
def github_repo_path(repo_url: str) -> str:
|
||||
"""Extract 'owner/repo' from a GitHub repository URL."""
|
||||
return repo_url.replace("https://github.com/", "")
|
||||
@@ -1,408 +0,0 @@
|
||||
import asyncio
|
||||
from enum import StrEnum
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import parse_data_uri, resolve_media_content
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListCommitsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to list commits from",
|
||||
default="main",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of commits to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class CommitItem(TypedDict):
|
||||
sha: str
|
||||
message: str
|
||||
author: str
|
||||
date: str
|
||||
url: str
|
||||
|
||||
commit: CommitItem = SchemaField(
|
||||
title="Commit", description="A commit with its details"
|
||||
)
|
||||
commits: list[CommitItem] = SchemaField(
|
||||
description="List of commits with their details"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing commits failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8b13f579-d8b6-4dc2-a140-f770428805de",
|
||||
description="This block lists commits on a branch in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListCommitsBlock.Input,
|
||||
output_schema=GithubListCommitsBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"commits",
|
||||
[
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"commit",
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_commits": lambda *args, **kwargs: [
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_commits(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
per_page: int,
|
||||
page: int,
|
||||
) -> list[Output.CommitItem]:
|
||||
api = get_api(credentials)
|
||||
commits_url = repo_url + "/commits"
|
||||
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
|
||||
response = await api.get(commits_url, params=params)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
return [
|
||||
GithubListCommitsBlock.Output.CommitItem(
|
||||
sha=c["sha"],
|
||||
message=c["commit"]["message"],
|
||||
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
|
||||
date=(c["commit"].get("author") or {}).get("date", ""),
|
||||
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
|
||||
)
|
||||
for c in data
|
||||
]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
commits = await self.list_commits(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "commits", commits
|
||||
for commit in commits:
|
||||
yield "commit", commit
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class FileOperation(StrEnum):
|
||||
"""File operations for GithubMultiFileCommitBlock.
|
||||
|
||||
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
|
||||
between creation and update — the blob is placed at the given path regardless
|
||||
of whether a file already exists there).
|
||||
|
||||
DELETE removes a file from the tree.
|
||||
"""
|
||||
|
||||
UPSERT = "upsert"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
# MediaFileType is a str NewType — no runtime breakage for existing callers.
|
||||
content: MediaFileType
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
class GithubMultiFileCommitBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to commit to",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Commit message",
|
||||
placeholder="Add new feature",
|
||||
)
|
||||
files: list[FileOperationInput] = SchemaField(
|
||||
description=(
|
||||
"List of file operations. Each item has: "
|
||||
"'path' (file path), 'content' (file content, ignored for delete), "
|
||||
"'operation' (upsert/delete)"
|
||||
),
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the new commit")
|
||||
url: str = SchemaField(description="URL of the new commit")
|
||||
error: str = SchemaField(description="Error message if the commit failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="389eee51-a95e-4230-9bed-92167a327802",
|
||||
description=(
|
||||
"This block creates a single commit with multiple file "
|
||||
"upsert/delete operations using the Git Trees API."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMultiFileCommitBlock.Input,
|
||||
output_schema=GithubMultiFileCommitBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "Add files",
|
||||
"files": [
|
||||
{
|
||||
"path": "src/new.py",
|
||||
"content": "print('hello')",
|
||||
"operation": "upsert",
|
||||
},
|
||||
{
|
||||
"path": "src/old.py",
|
||||
"content": "",
|
||||
"operation": "delete",
|
||||
},
|
||||
],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "newcommitsha"),
|
||||
("url", "https://github.com/owner/repo/commit/newcommitsha"),
|
||||
],
|
||||
test_mock={
|
||||
"multi_file_commit": lambda *args, **kwargs: (
|
||||
"newcommitsha",
|
||||
"https://github.com/owner/repo/commit/newcommitsha",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def multi_file_commit(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
files: list[FileOperationInput],
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
safe_branch = quote(branch, safe="")
|
||||
|
||||
# 1. Get the latest commit SHA for the branch
|
||||
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
|
||||
response = await api.get(ref_url)
|
||||
ref_data = response.json()
|
||||
latest_commit_sha = ref_data["object"]["sha"]
|
||||
|
||||
# 2. Get the tree SHA of the latest commit
|
||||
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
|
||||
response = await api.get(commit_url)
|
||||
commit_data = response.json()
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": encoding},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
tree_entries: list[dict] = []
|
||||
upsert_files = []
|
||||
for file_op in files:
|
||||
path = file_op["path"]
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
|
||||
if operation == FileOperation.DELETE:
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": None, # null SHA = delete
|
||||
}
|
||||
)
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently. Data URIs (from store_media_file)
|
||||
# are sent as base64 blobs to preserve binary content.
|
||||
if upsert_files:
|
||||
|
||||
async def _make_blob(content: str) -> str:
|
||||
parsed = parse_data_uri(content)
|
||||
if parsed is not None:
|
||||
_, b64_payload = parsed
|
||||
return await _create_blob(b64_payload, encoding="base64")
|
||||
return await _create_blob(content)
|
||||
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_make_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": blob_sha,
|
||||
}
|
||||
)
|
||||
|
||||
# 4. Create a new tree
|
||||
tree_url = repo_url + "/git/trees"
|
||||
tree_response = await api.post(
|
||||
tree_url,
|
||||
json={"base_tree": base_tree_sha, "tree": tree_entries},
|
||||
)
|
||||
new_tree_sha = tree_response.json()["sha"]
|
||||
|
||||
# 5. Create a new commit
|
||||
new_commit_url = repo_url + "/git/commits"
|
||||
commit_response = await api.post(
|
||||
new_commit_url,
|
||||
json={
|
||||
"message": commit_message,
|
||||
"tree": new_tree_sha,
|
||||
"parents": [latest_commit_sha],
|
||||
},
|
||||
)
|
||||
new_commit_sha = commit_response.json()["sha"]
|
||||
|
||||
# 6. Update the branch reference
|
||||
try:
|
||||
await api.patch(
|
||||
ref_url,
|
||||
json={"sha": new_commit_sha},
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Commit {new_commit_sha} was created but failed to update "
|
||||
f"ref heads/{branch}: {e}. "
|
||||
f"You can recover by manually updating the branch to {new_commit_sha}."
|
||||
) from e
|
||||
|
||||
repo_path = github_repo_path(repo_url)
|
||||
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
|
||||
return new_commit_sha, commit_web_url
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Resolve media references (workspace://, data:, URLs) to data
|
||||
# URIs so _make_blob can send binary content correctly.
|
||||
resolved_files: list[FileOperationInput] = []
|
||||
for file_op in input_data.files:
|
||||
content = file_op.get("content", "")
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
if operation != FileOperation.DELETE:
|
||||
content = await resolve_media_content(
|
||||
MediaFileType(content),
|
||||
execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
resolved_files.append(
|
||||
FileOperationInput(
|
||||
path=file_op["path"],
|
||||
content=MediaFileType(content),
|
||||
operation=operation,
|
||||
)
|
||||
)
|
||||
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
resolved_files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -21,8 +20,6 @@ from ._auth import (
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
MergeMethod = Literal["merge", "squash", "rebase"]
|
||||
|
||||
|
||||
class GithubListPullRequestsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -561,109 +558,12 @@ class GithubListPRReviewersBlock(Block):
|
||||
yield "reviewer", reviewer
|
||||
|
||||
|
||||
class GithubMergePullRequestBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
pr_url: str = SchemaField(
|
||||
description="URL of the GitHub pull request",
|
||||
placeholder="https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
merge_method: MergeMethod = SchemaField(
|
||||
description="Merge method to use: merge, squash, or rebase",
|
||||
default="merge",
|
||||
)
|
||||
commit_title: str = SchemaField(
|
||||
description="Title for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the merge commit")
|
||||
merged: bool = SchemaField(description="Whether the PR was merged")
|
||||
message: str = SchemaField(description="Merge status message")
|
||||
error: str = SchemaField(description="Error message if the merge failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
|
||||
description="This block merges a pull request using merge, squash, or rebase.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMergePullRequestBlock.Input,
|
||||
output_schema=GithubMergePullRequestBlock.Output,
|
||||
test_input={
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "abc123"),
|
||||
("merged", True),
|
||||
("message", "Pull Request successfully merged"),
|
||||
],
|
||||
test_mock={
|
||||
"merge_pr": lambda *args, **kwargs: (
|
||||
"abc123",
|
||||
True,
|
||||
"Pull Request successfully merged",
|
||||
)
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def merge_pr(
|
||||
credentials: GithubCredentials,
|
||||
pr_url: str,
|
||||
merge_method: MergeMethod,
|
||||
commit_title: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, bool, str]:
|
||||
api = get_api(credentials)
|
||||
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
|
||||
data: dict[str, str] = {"merge_method": merge_method}
|
||||
if commit_title:
|
||||
data["commit_title"] = commit_title
|
||||
if commit_message:
|
||||
data["commit_message"] = commit_message
|
||||
response = await api.put(merge_url, json=data)
|
||||
result = response.json()
|
||||
return result["sha"], result["merged"], result["message"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sha, merged, message = await self.merge_pr(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
input_data.merge_method,
|
||||
input_data.commit_title,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "merged", merged
|
||||
yield "message", message
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
def prepare_pr_api_url(pr_url: str, path: str) -> str:
|
||||
# Pattern to capture the base repository URL and the pull request number
|
||||
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
match = re.match(pattern, pr_url)
|
||||
if not match:
|
||||
return pr_url
|
||||
|
||||
scheme, base_url, pr_number = match.groups()
|
||||
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
|
||||
base_url, pr_number = match.groups()
|
||||
return f"{base_url}/pulls/{pr_number}/{path}"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import base64
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -17,7 +19,6 @@ from ._auth import (
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListTagsBlock(Block):
|
||||
@@ -88,7 +89,7 @@ class GithubListTagsBlock(Block):
|
||||
tags_url = repo_url + "/tags"
|
||||
response = await api.get(tags_url)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
tags: list[GithubListTagsBlock.Output.TagItem] = [
|
||||
{
|
||||
"name": tag["name"],
|
||||
@@ -114,6 +115,101 @@ class GithubListTagsBlock(Block):
|
||||
yield "tag", tag
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(branches_url)
|
||||
data = response.json()
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
|
||||
|
||||
class GithubListDiscussionsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -187,7 +283,7 @@ class GithubListDiscussionsBlock(Block):
|
||||
) -> list[Output.DiscussionItem]:
|
||||
api = get_api(credentials)
|
||||
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
|
||||
repo_path = github_repo_path(repo_url)
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
owner, repo = repo_path.split("/")
|
||||
query = """
|
||||
query($owner: String!, $repo: String!, $num: Int!) {
|
||||
@@ -320,6 +416,564 @@ class GithubListReleasesBlock(Block):
|
||||
yield "release", release
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to master)",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{branch}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -449,7 +1103,7 @@ class GithubListStargazersBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
|
||||
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
|
||||
description="This block lists all users who have starred a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListStargazersBlock.Input,
|
||||
@@ -518,230 +1172,3 @@ class GithubListStargazersBlock(Block):
|
||||
yield "stargazers", stargazers
|
||||
for stargazer in stargazers:
|
||||
yield "stargazer", stargazer
|
||||
|
||||
|
||||
class GithubGetRepositoryInfoBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
name: str = SchemaField(description="Repository name")
|
||||
full_name: str = SchemaField(description="Full repository name (owner/repo)")
|
||||
description: str = SchemaField(description="Repository description")
|
||||
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
|
||||
private: bool = SchemaField(description="Whether the repository is private")
|
||||
html_url: str = SchemaField(description="Web URL of the repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL")
|
||||
stars: int = SchemaField(description="Number of stars")
|
||||
forks: int = SchemaField(description="Number of forks")
|
||||
open_issues: int = SchemaField(description="Number of open issues")
|
||||
error: str = SchemaField(
|
||||
description="Error message if fetching repo info failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59d4f241-968a-4040-95da-348ac5c5ce27",
|
||||
description="This block retrieves metadata about a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryInfoBlock.Input,
|
||||
output_schema=GithubGetRepositoryInfoBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("name", "repo"),
|
||||
("full_name", "owner/repo"),
|
||||
("description", "A test repo"),
|
||||
("default_branch", "main"),
|
||||
("private", False),
|
||||
("html_url", "https://github.com/owner/repo"),
|
||||
("clone_url", "https://github.com/owner/repo.git"),
|
||||
("stars", 42),
|
||||
("forks", 5),
|
||||
("open_issues", 3),
|
||||
],
|
||||
test_mock={
|
||||
"get_repo_info": lambda *args, **kwargs: {
|
||||
"name": "repo",
|
||||
"full_name": "owner/repo",
|
||||
"description": "A test repo",
|
||||
"default_branch": "main",
|
||||
"private": False,
|
||||
"html_url": "https://github.com/owner/repo",
|
||||
"clone_url": "https://github.com/owner/repo.git",
|
||||
"stargazers_count": 42,
|
||||
"forks_count": 5,
|
||||
"open_issues_count": 3,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
|
||||
api = get_api(credentials)
|
||||
response = await api.get(repo_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_repo_info(credentials, input_data.repo_url)
|
||||
yield "name", data["name"]
|
||||
yield "full_name", data["full_name"]
|
||||
yield "description", data.get("description", "") or ""
|
||||
yield "default_branch", data["default_branch"]
|
||||
yield "private", data["private"]
|
||||
yield "html_url", data["html_url"]
|
||||
yield "clone_url", data["clone_url"]
|
||||
yield "stars", data["stargazers_count"]
|
||||
yield "forks", data["forks_count"]
|
||||
yield "open_issues", data["open_issues_count"]
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubForkRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to fork",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
organization: str = SchemaField(
|
||||
description="Organization to fork into (leave empty to fork to your account)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the forked repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL of the fork")
|
||||
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
|
||||
error: str = SchemaField(description="Error message if the fork failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
|
||||
description="This block forks a GitHub repository to your account or an organization.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubForkRepositoryBlock.Input,
|
||||
output_schema=GithubForkRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"organization": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/myuser/repo"),
|
||||
("clone_url", "https://github.com/myuser/repo.git"),
|
||||
("full_name", "myuser/repo"),
|
||||
],
|
||||
test_mock={
|
||||
"fork_repo": lambda *args, **kwargs: (
|
||||
"https://github.com/myuser/repo",
|
||||
"https://github.com/myuser/repo.git",
|
||||
"myuser/repo",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def fork_repo(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
organization: str,
|
||||
) -> tuple[str, str, str]:
|
||||
api = get_api(credentials)
|
||||
forks_url = repo_url + "/forks"
|
||||
data: dict[str, str] = {}
|
||||
if organization:
|
||||
data["organization"] = organization
|
||||
response = await api.post(forks_url, json=data)
|
||||
result = response.json()
|
||||
return result["html_url"], result["clone_url"], result["full_name"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, clone_url, full_name = await self.fork_repo(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.organization,
|
||||
)
|
||||
yield "url", url
|
||||
yield "clone_url", clone_url
|
||||
yield "full_name", full_name
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubStarRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to star",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the star operation")
|
||||
error: str = SchemaField(description="Error message if starring failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bd700764-53e3-44dd-a969-d1854088458f",
|
||||
description="This block stars a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubStarRepositoryBlock.Input,
|
||||
output_schema=GithubStarRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Repository starred successfully")],
|
||||
test_mock={
|
||||
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
repo_path = github_repo_path(repo_url)
|
||||
owner, repo = repo_path.split("/")
|
||||
await api.put(
|
||||
f"https://api.github.com/user/starred/{owner}/{repo}",
|
||||
headers={"Content-Length": "0"},
|
||||
)
|
||||
return "Repository starred successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.star_repo(credentials, input_data.repo_url)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -1,452 +0,0 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of branches to return per page (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing branches failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(
|
||||
branches_url, params={"per_page": str(per_page), "page": str(page)}
|
||||
)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCompareBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
base: str = SchemaField(
|
||||
description="Base branch or commit SHA",
|
||||
placeholder="main",
|
||||
)
|
||||
head: str = SchemaField(
|
||||
description="Head branch or commit SHA to compare against base",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class FileChange(TypedDict):
|
||||
filename: str
|
||||
status: str
|
||||
additions: int
|
||||
deletions: int
|
||||
patch: str
|
||||
|
||||
status: str = SchemaField(
|
||||
description="Comparison status: ahead, behind, diverged, or identical"
|
||||
)
|
||||
ahead_by: int = SchemaField(
|
||||
description="Number of commits head is ahead of base"
|
||||
)
|
||||
behind_by: int = SchemaField(
|
||||
description="Number of commits head is behind base"
|
||||
)
|
||||
total_commits: int = SchemaField(
|
||||
description="Total number of commits in the comparison"
|
||||
)
|
||||
diff: str = SchemaField(description="Unified diff of all file changes")
|
||||
file: FileChange = SchemaField(
|
||||
title="Changed File", description="A changed file with its diff"
|
||||
)
|
||||
files: list[FileChange] = SchemaField(
|
||||
description="List of changed files with their diffs"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if comparison failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2e4faa8c-6086-4546-ba77-172d1d560186",
|
||||
description="This block compares two branches or commits in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCompareBranchesBlock.Input,
|
||||
output_schema=GithubCompareBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"base": "main",
|
||||
"head": "feature",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("status", "ahead"),
|
||||
("ahead_by", 2),
|
||||
("behind_by", 0),
|
||||
("total_commits", 2),
|
||||
("diff", "+++ b/file.py\n+new line"),
|
||||
(
|
||||
"files",
|
||||
[
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"compare_branches": lambda *args, **kwargs: {
|
||||
"status": "ahead",
|
||||
"ahead_by": 2,
|
||||
"behind_by": 0,
|
||||
"total_commits": 2,
|
||||
"files": [
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def compare_branches(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
base: str,
|
||||
head: str,
|
||||
) -> dict:
|
||||
api = get_api(credentials)
|
||||
safe_base = quote(base, safe="")
|
||||
safe_head = quote(head, safe="")
|
||||
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
|
||||
response = await api.get(compare_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.compare_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.base,
|
||||
input_data.head,
|
||||
)
|
||||
yield "status", data["status"]
|
||||
yield "ahead_by", data["ahead_by"]
|
||||
yield "behind_by", data["behind_by"]
|
||||
yield "total_commits", data["total_commits"]
|
||||
|
||||
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
|
||||
GithubCompareBranchesBlock.Output.FileChange(
|
||||
filename=f["filename"],
|
||||
status=f["status"],
|
||||
additions=f["additions"],
|
||||
deletions=f["deletions"],
|
||||
patch=f.get("patch", ""),
|
||||
)
|
||||
for f in data.get("files", [])
|
||||
]
|
||||
|
||||
# Build unified diff
|
||||
diff_parts = []
|
||||
for f in data.get("files", []):
|
||||
patch = f.get("patch", "")
|
||||
if patch:
|
||||
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
|
||||
yield "diff", "\n".join(diff_parts)
|
||||
|
||||
yield "files", files
|
||||
for file in files:
|
||||
yield "file", file
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,720 +0,0 @@
|
||||
import base64
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
error: str = SchemaField(description="Error message if reading the file failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to main)",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubSearchCodeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
query: str = SchemaField(
|
||||
description="Search query (GitHub code search syntax)",
|
||||
placeholder="className language:python",
|
||||
)
|
||||
repo: str = SchemaField(
|
||||
description="Restrict search to a repository (owner/repo format, optional)",
|
||||
default="",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of results to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class SearchResult(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
repository: str
|
||||
url: str
|
||||
score: float
|
||||
|
||||
result: SearchResult = SchemaField(
|
||||
title="Result", description="A code search result"
|
||||
)
|
||||
results: list[SearchResult] = SchemaField(
|
||||
description="List of code search results"
|
||||
)
|
||||
total_count: int = SchemaField(description="Total number of matching results")
|
||||
error: str = SchemaField(description="Error message if search failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
|
||||
description="This block searches for code in GitHub repositories.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubSearchCodeBlock.Input,
|
||||
output_schema=GithubSearchCodeBlock.Output,
|
||||
test_input={
|
||||
"query": "addClass",
|
||||
"repo": "owner/repo",
|
||||
"per_page": 30,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("total_count", 1),
|
||||
(
|
||||
"results",
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"result",
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"search_code": lambda *args, **kwargs: (
|
||||
1,
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def search_code(
|
||||
credentials: GithubCredentials,
|
||||
query: str,
|
||||
repo: str,
|
||||
per_page: int,
|
||||
) -> tuple[int, list[Output.SearchResult]]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
full_query = f"{query} repo:{repo}" if repo else query
|
||||
params = {"q": full_query, "per_page": str(per_page)}
|
||||
response = await api.get("https://api.github.com/search/code", params=params)
|
||||
data = response.json()
|
||||
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
|
||||
GithubSearchCodeBlock.Output.SearchResult(
|
||||
name=item["name"],
|
||||
path=item["path"],
|
||||
repository=item["repository"]["full_name"],
|
||||
url=item["html_url"],
|
||||
score=item["score"],
|
||||
)
|
||||
for item in data["items"]
|
||||
]
|
||||
return data["total_count"], results
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total_count, results = await self.search_code(
|
||||
credentials,
|
||||
input_data.query,
|
||||
input_data.repo,
|
||||
input_data.per_page,
|
||||
)
|
||||
yield "total_count", total_count
|
||||
yield "results", results
|
||||
for result in results:
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubGetRepositoryTreeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to get the tree from",
|
||||
default="main",
|
||||
)
|
||||
recursive: bool = SchemaField(
|
||||
description="Whether to recursively list the entire tree",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class TreeEntry(TypedDict):
|
||||
path: str
|
||||
type: str
|
||||
size: int
|
||||
sha: str
|
||||
|
||||
entry: TreeEntry = SchemaField(
|
||||
title="Tree Entry", description="A file or directory in the tree"
|
||||
)
|
||||
entries: list[TreeEntry] = SchemaField(
|
||||
description="List of all files and directories in the tree"
|
||||
)
|
||||
truncated: bool = SchemaField(
|
||||
description="Whether the tree was truncated due to size"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if getting tree failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
|
||||
description="This block lists the entire file tree of a GitHub repository recursively.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryTreeBlock.Input,
|
||||
output_schema=GithubGetRepositoryTreeBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"recursive": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("truncated", False),
|
||||
(
|
||||
"entries",
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"entry",
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_tree": lambda *args, **kwargs: (
|
||||
False,
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_tree(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
recursive: bool,
|
||||
) -> tuple[bool, list[Output.TreeEntry]]:
|
||||
api = get_api(credentials)
|
||||
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
|
||||
params = {"recursive": "1"} if recursive else {}
|
||||
response = await api.get(tree_url, params=params)
|
||||
data = response.json()
|
||||
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
|
||||
GithubGetRepositoryTreeBlock.Output.TreeEntry(
|
||||
path=item["path"],
|
||||
type=item["type"],
|
||||
size=item.get("size", 0),
|
||||
sha=item["sha"],
|
||||
)
|
||||
for item in data["tree"]
|
||||
]
|
||||
return data.get("truncated", False), entries
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
truncated, entries = await self.get_tree(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.recursive,
|
||||
)
|
||||
yield "truncated", truncated
|
||||
yield "entries", entries
|
||||
for entry in entries:
|
||||
yield "entry", entry
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,125 +0,0 @@
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
|
||||
from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
|
||||
|
||||
class TestPreparePrApiUrl:
|
||||
def test_https_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/42/merge"
|
||||
|
||||
def test_http_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
|
||||
assert result == "http://github.com/owner/repo/pulls/1/files"
|
||||
|
||||
def test_no_scheme_defaults_to_https(self):
|
||||
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/5/merge"
|
||||
|
||||
def test_reviewers_path(self):
|
||||
result = prepare_pr_api_url(
|
||||
"https://github.com/owner/repo/pull/99", "requested_reviewers"
|
||||
)
|
||||
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
|
||||
|
||||
def test_invalid_url_returned_as_is(self):
|
||||
url = "https://example.com/not-a-pr"
|
||||
assert prepare_pr_api_url(url, "merge") == url
|
||||
|
||||
def test_empty_string(self):
|
||||
assert prepare_pr_api_url("", "merge") == ""
|
||||
|
||||
|
||||
# ── Error-path block tests ──
|
||||
# When a block's run() yields ("error", msg), _execute() converts it to a
|
||||
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
|
||||
# which returns early on empty test_output).
|
||||
|
||||
|
||||
def _mock_block(block, mocks: dict):
|
||||
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
|
||||
for name, mock_fn in mocks.items():
|
||||
original = getattr(block, name)
|
||||
if inspect.iscoroutinefunction(original):
|
||||
|
||||
async def async_mock(*args, _fn=mock_fn, **kwargs):
|
||||
return _fn(*args, **kwargs)
|
||||
|
||||
setattr(block, name, async_mock)
|
||||
else:
|
||||
setattr(block, name, mock_fn)
|
||||
|
||||
|
||||
def _raise(exc: Exception):
|
||||
"""Helper that returns a callable which raises the given exception."""
|
||||
|
||||
def _raiser(*args, **kwargs):
|
||||
raise exc
|
||||
|
||||
return _raiser
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pr_error_path():
|
||||
block = GithubMergePullRequestBlock()
|
||||
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
|
||||
input_data = {
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_file_commit_error_path():
|
||||
block = GithubMultiFileCommitBlock()
|
||||
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "test",
|
||||
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(
|
||||
input_data,
|
||||
credentials=TEST_CREDENTIALS,
|
||||
execution_context=ExecutionContext(),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# ── FileOperation enum tests ──
|
||||
|
||||
|
||||
class TestFileOperation:
|
||||
def test_upsert_value(self):
|
||||
assert FileOperation.UPSERT == "upsert"
|
||||
|
||||
def test_delete_value(self):
|
||||
assert FileOperation.DELETE == "delete"
|
||||
|
||||
def test_invalid_value_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("create")
|
||||
|
||||
def test_invalid_value_raises_typo(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("upser")
|
||||
@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
|
||||
h.ignore_links = False
|
||||
h.ignore_images = True
|
||||
return h.handle(html_content)
|
||||
except Exception:
|
||||
# Keep extraction resilient if html2text is unavailable or fails.
|
||||
except ImportError:
|
||||
# Fallback: return raw HTML if html2text is not available
|
||||
return html_content
|
||||
|
||||
# Handle content stored as attachment
|
||||
|
||||
@@ -140,31 +140,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
|
||||
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
|
||||
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
|
||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
|
||||
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
|
||||
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
|
||||
CODESTRAL = "mistralai/codestral-2508"
|
||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
|
||||
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
|
||||
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
|
||||
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
|
||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
|
||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||
@@ -172,11 +160,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||
MICROSOFT_PHI_4 = "microsoft/phi-4"
|
||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||
GROK_3 = "x-ai/grok-3"
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
@@ -354,41 +340,17 @@ MODEL_METADATA = {
|
||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
||||
),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
1050000,
|
||||
8192,
|
||||
"Gemini 2.5 Pro Preview 03.25",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 2.5 Pro",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Pro Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3 Flash Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
||||
@@ -396,15 +358,6 @@ MODEL_METADATA = {
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Flash Lite Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
@@ -426,78 +379,12 @@ MODEL_METADATA = {
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
||||
),
|
||||
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
|
||||
"open_router",
|
||||
262144,
|
||||
None,
|
||||
"Mistral Large 3 2512",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
None,
|
||||
"Mistral Medium 3.1",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Mistral Small 3.2 24B",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.CODESTRAL: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
None,
|
||||
"Codestral 2508",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
|
||||
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Translate 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
32768,
|
||||
"Command A Reasoning 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
3,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Vision 07.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
@@ -510,15 +397,6 @@ MODEL_METADATA = {
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8000,
|
||||
"Sonar Reasoning Pro",
|
||||
"OpenRouter",
|
||||
"Perplexity",
|
||||
2,
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
@@ -564,9 +442,6 @@ MODEL_METADATA = {
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
|
||||
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
||||
),
|
||||
@@ -576,15 +451,6 @@ MODEL_METADATA = {
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.GROK_3: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Grok 3",
|
||||
"OpenRouter",
|
||||
"xAI",
|
||||
2,
|
||||
),
|
||||
LlmModel.GROK_4: ModelMetadata(
|
||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
||||
),
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from pydantic import SecretStr, field_validator
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
@@ -13,7 +13,6 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -36,20 +35,6 @@ class PerplexityModel(str, Enum):
|
||||
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
|
||||
|
||||
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
|
||||
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
|
||||
if isinstance(value, PerplexityModel):
|
||||
return value
|
||||
try:
|
||||
return PerplexityModel(value)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid PerplexityModel '{value}', "
|
||||
f"falling back to {PerplexityModel.SONAR.value}"
|
||||
)
|
||||
return PerplexityModel.SONAR
|
||||
|
||||
|
||||
PerplexityCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
|
||||
]
|
||||
@@ -88,25 +73,6 @@ class PerplexityBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
credentials: PerplexityCredentials = PerplexityCredentialsField()
|
||||
|
||||
@field_validator("model", mode="before")
|
||||
@classmethod
|
||||
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
|
||||
"""Fall back to SONAR if the model value is not a valid
|
||||
PerplexityModel (e.g. an OpenAI model ID set by the agent
|
||||
generator)."""
|
||||
return _sanitize_perplexity_model(v)
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
"""Sanitize the model field before JSON schema validation so that
|
||||
invalid values are replaced with the default instead of raising a
|
||||
BlockInputError."""
|
||||
model_value = data.get("model")
|
||||
if model_value is not None:
|
||||
data["model"] = _sanitize_perplexity_model(model_value).value
|
||||
return super().validate_data(data)
|
||||
|
||||
system_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="",
|
||||
|
||||
@@ -2232,7 +2232,6 @@ class DeleteRedditPostBlock(Block):
|
||||
("post_id", "abc123"),
|
||||
],
|
||||
test_mock={"delete_post": lambda creds, post_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -2291,7 +2290,6 @@ class DeleteRedditCommentBlock(Block):
|
||||
("comment_id", "xyz789"),
|
||||
],
|
||||
test_mock={"delete_comment": lambda creds, comment_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -72,7 +72,6 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
|
||||
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
|
||||
"_convert_to_color": lambda *args, **kwargs: "black",
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
"""Unit tests for PerplexityBlock model fallback behavior."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.perplexity import (
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
PerplexityBlock,
|
||||
PerplexityModel,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**overrides) -> dict:
|
||||
defaults = {
|
||||
"prompt": "test query",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return defaults
|
||||
|
||||
|
||||
class TestPerplexityModelFallback:
|
||||
"""Tests for fallback_invalid_model field_validator."""
|
||||
|
||||
def test_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_another_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_valid_model_string_is_kept(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
|
||||
assert inp.model == PerplexityModel.SONAR_PRO
|
||||
|
||||
def test_valid_enum_value_is_kept(self):
|
||||
inp = PerplexityBlock.Input(
|
||||
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
|
||||
)
|
||||
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
|
||||
|
||||
def test_default_model_when_omitted(self):
|
||||
inp = PerplexityBlock.Input(**_make_input())
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_value",
|
||||
[
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
)
|
||||
def test_all_valid_models_accepted(self, model_value: str):
|
||||
inp = PerplexityBlock.Input(**_make_input(model=model_value))
|
||||
assert inp.model.value == model_value
|
||||
|
||||
|
||||
class TestPerplexityValidateData:
|
||||
"""Tests for validate_data which runs during block execution (before
|
||||
Pydantic instantiation). Invalid models must be sanitized here so
|
||||
JSON schema validation does not reject them."""
|
||||
|
||||
def test_invalid_model_sanitized_before_schema_validation(self):
|
||||
data = _make_input(model="gpt-5.2-2025-12-11")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == PerplexityModel.SONAR.value
|
||||
|
||||
def test_valid_model_unchanged_by_validate_data(self):
|
||||
data = _make_input(model="perplexity/sonar-pro")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == "perplexity/sonar-pro"
|
||||
|
||||
def test_missing_model_uses_default(self):
|
||||
data = _make_input() # no model key
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
inp = PerplexityBlock.Input(**data)
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
@@ -115,7 +115,7 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
|
||||
@@ -11,23 +11,12 @@ from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
# Allowed base directory for the Read tool. Public so service.py can use it
|
||||
# for sweep operations without depending on a private implementation detail.
|
||||
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
|
||||
# _projects_base() function.
|
||||
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
|
||||
|
||||
# Compiled UUID pattern for validating conversation directory names.
|
||||
# Kept as a module-level constant so the security-relevant pattern is easy
|
||||
# to audit in one place and avoids recompilation on every call.
|
||||
_UUID_RE = re.compile(r"^[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}$", re.IGNORECASE)
|
||||
# Allowed base directory for the Read tool.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Encoded project-directory name for the current session (e.g.
|
||||
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||
@@ -44,20 +33,11 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
|
||||
def encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The Claude CLI encodes the absolute cwd as a directory name by replacing
|
||||
every non-alphanumeric character with ``-``. For example
|
||||
``/tmp/copilot-abc`` becomes ``-tmp-copilot-abc``.
|
||||
"""
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does."""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
# Keep the private alias for internal callers (backwards compat).
|
||||
_encode_cwd_for_cli = encode_cwd_for_cli
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
@@ -102,25 +82,12 @@ def resolve_sandbox_path(path: str) -> str:
|
||||
return normalized
|
||||
|
||||
|
||||
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped :class:`WorkspaceManager`.
|
||||
|
||||
Placed here (rather than in ``tools/workspace_files``) so that modules
|
||||
like ``sdk/file_ref`` can import it without triggering the heavy
|
||||
``tools/__init__`` import chain.
|
||||
"""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Return True if *path* is within an allowed host-filesystem location.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
|
||||
The SDK nests tool-results under a conversation UUID directory;
|
||||
the UUID segment is validated with ``_UUID_RE``.
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
@@ -139,22 +106,10 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
project_dir = os.path.realpath(os.path.join(SDK_PROJECTS_DIR, encoded))
|
||||
# Defence-in-depth: ensure project_dir didn't escape the base.
|
||||
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
|
||||
return False
|
||||
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
|
||||
# The SDK always creates a conversation UUID directory between
|
||||
# the project dir and tool-results/.
|
||||
if resolved.startswith(project_dir + os.sep):
|
||||
relative = resolved[len(project_dir) + 1 :]
|
||||
parts = relative.split(os.sep)
|
||||
# Require exactly: [<uuid>, "tool-results", <file>, ...]
|
||||
if (
|
||||
len(parts) >= 3
|
||||
and _UUID_RE.match(parts[0])
|
||||
and parts[1] == "tool-results"
|
||||
):
|
||||
return True
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
if resolved == tool_results_dir or resolved.startswith(
|
||||
tool_results_dir + os.sep
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -9,7 +9,7 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import (
|
||||
SDK_PROJECTS_DIR,
|
||||
_SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
@@ -104,13 +104,11 @@ def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_with_uuid():
|
||||
"""Files under <encoded-cwd>/<uuid>/tool-results/ are allowed."""
|
||||
def test_is_allowed_local_path_tool_results_dir():
|
||||
"""Files under the tool-results directory for the current project are allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results", "output.txt"
|
||||
)
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
path = os.path.join(tool_results_dir, "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -119,22 +117,10 @@ def test_is_allowed_local_path_tool_results_with_uuid():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
|
||||
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
path = os.path.join(SDK_PROJECTS_DIR, encoded, "tool-results", "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
sibling_path = os.path.join(SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -143,21 +129,6 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
|
||||
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
uuid_str = "12345678-1234-5678-9abc-def012345678"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, uuid_str, "not-tool-results", "output.txt"
|
||||
)
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -52,43 +52,6 @@ Examples:
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
**Structured data**: When the **entire** argument value is a single file
|
||||
reference (no surrounding text), the platform automatically parses the file
|
||||
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||
unrecognised or parsing fails, the content is returned as a plain string.
|
||||
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||
|
||||
**Type coercion**: The platform also coerces expanded values to match the
|
||||
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||
|
||||
### Media file inputs (format: "file")
|
||||
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||
These fields accept:
|
||||
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
|
||||
for large files (images, videos, PDFs). The platform passes the reference
|
||||
directly to the block without reading the content into memory.
|
||||
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
|
||||
small files only.
|
||||
|
||||
When a block input has `format: "file"`, **pass the `workspace://` URI
|
||||
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
|
||||
payloads in tool arguments and preserves binary content (images, videos)
|
||||
that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{
|
||||
"files": [{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
@@ -152,13 +115,6 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
|
||||
### SDK tool-result files
|
||||
When tool outputs are large, the SDK truncates them and saves the full output to
|
||||
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
|
||||
always use `read_file` or `Read` (NOT `read_workspace_file`).
|
||||
`read_workspace_file` reads from cloud workspace storage, where SDK
|
||||
tool-results are NOT stored.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
|
||||
@@ -3,45 +3,12 @@
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
|
||||
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
|
||||
circular import cycle::
|
||||
|
||||
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
|
||||
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
|
||||
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
|
||||
|
||||
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
|
||||
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
|
||||
function scope without a larger refactor (moving tool-name registration
|
||||
to a separate lightweight module). The lazy-import pattern here is the
|
||||
least invasive way to break the cycle while keeping module-level constants
|
||||
intact.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
entry = _LAZY_IMPORTS.get(name)
|
||||
if entry is not None:
|
||||
module_path, attr = entry
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path, package=__name__)
|
||||
value = getattr(module, attr)
|
||||
globals()[name] = value
|
||||
return value
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -11,7 +11,7 @@ persistence, and the ``CompactionTracker`` state machine.
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from collections.abc import Callable
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
@@ -27,19 +27,6 @@ from ..response_model import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompactionResult:
|
||||
"""Result of emit_end_if_ready — bundles events with compaction metadata.
|
||||
|
||||
Eliminates the need for separate ``compaction_just_ended`` checks,
|
||||
preventing TOCTOU races between the emit call and the flag read.
|
||||
"""
|
||||
|
||||
events: list[StreamBaseResponse] = field(default_factory=list)
|
||||
just_ended: bool = False
|
||||
transcript_path: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event builders (private — use CompactionTracker or compaction_events)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -190,22 +177,11 @@ class CompactionTracker:
|
||||
self._start_emitted = False
|
||||
self._done = False
|
||||
self._tool_call_id = ""
|
||||
self._transcript_path: str = ""
|
||||
|
||||
def on_compact(self, transcript_path: str = "") -> None:
|
||||
"""Callback for the PreCompact hook. Stores transcript_path."""
|
||||
if (
|
||||
self._transcript_path
|
||||
and transcript_path
|
||||
and self._transcript_path != transcript_path
|
||||
):
|
||||
logger.warning(
|
||||
"[Compaction] Overwriting transcript_path %s -> %s",
|
||||
self._transcript_path,
|
||||
transcript_path,
|
||||
)
|
||||
self._transcript_path = transcript_path
|
||||
self._compact_start.set()
|
||||
@property
|
||||
def on_compact(self) -> Callable[[], None]:
|
||||
"""Callback for the PreCompact hook."""
|
||||
return self._compact_start.set
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pre-query compaction
|
||||
@@ -225,7 +201,6 @@ class CompactionTracker:
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
self._transcript_path = ""
|
||||
|
||||
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
|
||||
"""If the PreCompact hook fired, emit start events (spinning tool)."""
|
||||
@@ -236,20 +211,15 @@ class CompactionTracker:
|
||||
return _start_events(self._tool_call_id)
|
||||
return []
|
||||
|
||||
async def emit_end_if_ready(self, session: ChatSession) -> CompactionResult:
|
||||
"""If compaction is in progress, emit end events and persist.
|
||||
|
||||
Returns a ``CompactionResult`` with ``just_ended=True`` and the
|
||||
captured ``transcript_path`` when a compaction cycle completes.
|
||||
This avoids a separate flag check (TOCTOU-safe).
|
||||
"""
|
||||
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
|
||||
"""If compaction is in progress, emit end events and persist."""
|
||||
# Yield so pending hook tasks can set compact_start
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if self._done:
|
||||
return CompactionResult()
|
||||
return []
|
||||
if not self._start_emitted and not self._compact_start.is_set():
|
||||
return CompactionResult()
|
||||
return []
|
||||
|
||||
if self._start_emitted:
|
||||
# Close the open spinner
|
||||
@@ -262,12 +232,8 @@ class CompactionTracker:
|
||||
COMPACTION_DONE_MSG, tool_call_id=persist_id
|
||||
)
|
||||
|
||||
transcript_path = self._transcript_path
|
||||
self._compact_start.clear()
|
||||
self._start_emitted = False
|
||||
self._done = True
|
||||
self._transcript_path = ""
|
||||
_persist(session, persist_id, COMPACTION_DONE_MSG)
|
||||
return CompactionResult(
|
||||
events=done_events, just_ended=True, transcript_path=transcript_path
|
||||
)
|
||||
return done_events
|
||||
|
||||
@@ -195,11 +195,10 @@ class TestCompactionTracker:
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert len(result.events) == 2
|
||||
assert isinstance(result.events[0], StreamToolOutputAvailable)
|
||||
assert isinstance(result.events[1], StreamFinishStep)
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert len(evts) == 2
|
||||
assert isinstance(evts[0], StreamToolOutputAvailable)
|
||||
assert isinstance(evts[1], StreamFinishStep)
|
||||
# Should persist
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@@ -211,32 +210,28 @@ class TestCompactionTracker:
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
# Don't call emit_start_if_ready
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert len(result.events) == 5 # Full self-contained event
|
||||
assert isinstance(result.events[0], StreamStartStep)
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert len(evts) == 5 # Full self-contained event
|
||||
assert isinstance(evts[0], StreamStartStep)
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_no_op_when_no_new_compaction(self):
|
||||
async def test_emit_end_no_op_when_done(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = await tracker.emit_end_if_ready(session)
|
||||
assert result1.just_ended is True
|
||||
# Second call should be no-op (no new on_compact)
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is False
|
||||
assert result2.events == []
|
||||
await tracker.emit_end_if_ready(session)
|
||||
# Second call should be no-op
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert evts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_no_op_when_nothing_happened(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is False
|
||||
assert result.events == []
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert evts == []
|
||||
|
||||
def test_emit_pre_query(self):
|
||||
tracker = CompactionTracker()
|
||||
@@ -251,29 +246,20 @@ class TestCompactionTracker:
|
||||
tracker._done = True
|
||||
tracker._start_emitted = True
|
||||
tracker._tool_call_id = "old"
|
||||
tracker._transcript_path = "/some/path"
|
||||
tracker.reset_for_query()
|
||||
assert tracker._done is False
|
||||
assert tracker._start_emitted is False
|
||||
assert tracker._tool_call_id == ""
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_query_blocks_sdk_compaction_until_reset(self):
|
||||
"""After pre-query compaction, SDK compaction is blocked until
|
||||
reset_for_query is called."""
|
||||
async def test_pre_query_blocks_sdk_compaction(self):
|
||||
"""After pre-query compaction, SDK compaction events are suppressed."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.emit_pre_query(session)
|
||||
tracker.on_compact()
|
||||
# _done is True so emit_start_if_ready is blocked
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert evts == []
|
||||
# Reset clears _done, allowing subsequent compaction
|
||||
tracker.reset_for_query()
|
||||
tracker.on_compact()
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert len(evts) == 3
|
||||
assert evts == [] # _done blocks it
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_allows_new_compaction(self):
|
||||
@@ -293,9 +279,9 @@ class TestCompactionTracker:
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
end_evts = await tracker.emit_end_if_ready(session)
|
||||
start_evt = start_evts[1]
|
||||
end_evt = result.events[0]
|
||||
end_evt = end_evts[0]
|
||||
assert isinstance(start_evt, StreamToolInputStart)
|
||||
assert isinstance(end_evt, StreamToolOutputAvailable)
|
||||
assert start_evt.toolCallId == end_evt.toolCallId
|
||||
@@ -303,105 +289,3 @@ class TestCompactionTracker:
|
||||
tool_calls = session.messages[0].tool_calls
|
||||
assert tool_calls is not None
|
||||
assert tool_calls[0]["id"] == start_evt.toolCallId
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compactions_within_query(self):
|
||||
"""Two mid-stream compactions within a single query both trigger."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
# First compaction cycle
|
||||
tracker.on_compact("/path/1")
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = await tracker.emit_end_if_ready(session)
|
||||
assert result1.just_ended is True
|
||||
assert len(result1.events) == 2
|
||||
assert result1.transcript_path == "/path/1"
|
||||
|
||||
# Second compaction cycle (should NOT be blocked — _done resets
|
||||
# because emit_end_if_ready sets it True, but the next on_compact
|
||||
# + emit_start_if_ready checks !_done which IS True now.
|
||||
# So we need reset_for_query between queries, but within a single
|
||||
# query multiple compactions work because _done blocks emit_start
|
||||
# until the next message arrives, at which point emit_end detects it)
|
||||
#
|
||||
# Actually: _done=True blocks emit_start_if_ready, so we need
|
||||
# the stream loop to reset. In practice service.py doesn't call
|
||||
# reset between compactions within the same query — let's verify
|
||||
# the actual behavior.
|
||||
tracker.on_compact("/path/2")
|
||||
# _done is True from first compaction, so start is blocked
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
assert start_evts == []
|
||||
# But emit_end returns no-op because _done is True
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compactions_with_intervening_message(self):
|
||||
"""Multiple compactions work when the stream loop processes messages between them.
|
||||
|
||||
In the real service.py flow:
|
||||
1. PreCompact fires → on_compact()
|
||||
2. emit_start shows spinner
|
||||
3. Next message arrives → emit_end completes compaction (_done=True)
|
||||
4. Stream continues processing messages...
|
||||
5. If a second PreCompact fires, _done=True blocks emit_start
|
||||
6. But the next message triggers emit_end, which sees _done=True → no-op
|
||||
7. The stream loop needs to detect this and handle accordingly
|
||||
|
||||
The actual flow for multiple compactions within a query requires
|
||||
_done to be cleared between them. The service.py code uses
|
||||
CompactionResult.just_ended to trigger replace_entries, and _done
|
||||
stays True until reset_for_query.
|
||||
"""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
# First compaction
|
||||
tracker.on_compact("/path/1")
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = await tracker.emit_end_if_ready(session)
|
||||
assert result1.just_ended is True
|
||||
assert result1.transcript_path == "/path/1"
|
||||
|
||||
# Simulate reset between queries
|
||||
tracker.reset_for_query()
|
||||
|
||||
# Second compaction in new query
|
||||
tracker.on_compact("/path/2")
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
assert len(start_evts) == 3
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is True
|
||||
assert result2.transcript_path == "/path/2"
|
||||
|
||||
def test_on_compact_stores_transcript_path(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact("/some/path.jsonl")
|
||||
assert tracker._transcript_path == "/some/path.jsonl"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_returns_transcript_path(self):
|
||||
"""CompactionResult includes the transcript_path from on_compact."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact("/my/session.jsonl")
|
||||
tracker.emit_start_if_ready()
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert result.transcript_path == "/my/session.jsonl"
|
||||
# transcript_path is cleared after emit_end
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_clears_transcript_path(self):
|
||||
"""After emit_end, _transcript_path is reset so it doesn't leak to
|
||||
subsequent non-compaction emit_end calls."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact("/first/path.jsonl")
|
||||
tracker.emit_start_if_ready()
|
||||
await tracker.emit_end_if_ready(session)
|
||||
# After compaction, _transcript_path is cleared
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@@ -26,41 +26,6 @@ from backend.copilot.context import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _check_sandbox_symlink_escape(
|
||||
sandbox: Any,
|
||||
parent: str,
|
||||
) -> str | None:
|
||||
"""Resolve the canonical parent path inside the sandbox to detect symlink escapes.
|
||||
|
||||
``normpath`` (used by ``resolve_sandbox_path``) only normalises the string;
|
||||
``readlink -f`` follows actual symlinks on the sandbox filesystem.
|
||||
|
||||
Returns the canonical parent path, or ``None`` if the path escapes
|
||||
``E2B_WORKDIR``.
|
||||
|
||||
Note: There is an inherent TOCTOU window between this check and the
|
||||
subsequent ``sandbox.files.write()``. A symlink could theoretically be
|
||||
replaced between the two operations. This is acceptable in the E2B
|
||||
sandbox model since the sandbox is single-user and ephemeral.
|
||||
"""
|
||||
canonical_res = await sandbox.commands.run(
|
||||
f"readlink -f {shlex.quote(parent or E2B_WORKDIR)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=5,
|
||||
)
|
||||
canonical_parent = (canonical_res.stdout or "").strip()
|
||||
if (
|
||||
canonical_res.exit_code != 0
|
||||
or not canonical_parent
|
||||
or (
|
||||
canonical_parent != E2B_WORKDIR
|
||||
and not canonical_parent.startswith(E2B_WORKDIR + "/")
|
||||
)
|
||||
):
|
||||
return None
|
||||
return canonical_parent
|
||||
|
||||
|
||||
def _get_sandbox():
|
||||
return get_current_sandbox()
|
||||
|
||||
@@ -141,10 +106,6 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
await sandbox.files.make_dir(parent)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
@@ -169,12 +130,6 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
parent = os.path.dirname(remote)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
|
||||
@@ -4,19 +4,15 @@ Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
from .e2b_file_tools import (
|
||||
_check_sandbox_symlink_escape,
|
||||
_read_local,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
@@ -25,48 +21,46 @@ from .e2b_file_tools import (
|
||||
|
||||
class TestResolveSandboxPath:
|
||||
def test_relative_path_resolved(self):
|
||||
assert resolve_sandbox_path("src/main.py") == f"{E2B_WORKDIR}/src/main.py"
|
||||
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert (
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/file.txt") == f"{E2B_WORKDIR}/file.txt"
|
||||
)
|
||||
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert resolve_sandbox_path(E2B_WORKDIR) == E2B_WORKDIR
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
|
||||
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == f"{E2B_WORKDIR}/a/b/c/d/e.txt"
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert resolve_sandbox_path("src/") == f"{E2B_WORKDIR}/src"
|
||||
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within E2B_WORKDIR is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -79,13 +73,9 @@ class TestResolveSandboxPath:
|
||||
|
||||
|
||||
class TestReadLocal:
|
||||
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
|
||||
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
|
||||
"""Create a tool-results file under <encoded>/<uuid>/tool-results/."""
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
@@ -117,9 +107,7 @@ class TestReadLocal:
|
||||
def test_read_nonexistent_tool_results(self):
|
||||
"""A tool-results path that doesn't exist returns FileNotFoundError."""
|
||||
encoded = "-tmp-copilot-e2b-test-nofile"
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
@@ -129,7 +117,7 @@ class TestReadLocal:
|
||||
assert "not found" in result["content"][0]["text"].lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
os.rmdir(tool_results_dir)
|
||||
|
||||
def test_read_traversal_path_blocked(self):
|
||||
"""A traversal attempt that escapes allowed directories is blocked."""
|
||||
@@ -164,66 +152,3 @@ class TestReadLocal:
|
||||
"""Without _current_project_dir set, all paths are blocked."""
|
||||
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_sandbox_symlink_escape — symlink escape detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_sandbox(stdout: str, exit_code: int = 0) -> SimpleNamespace:
|
||||
"""Build a minimal sandbox mock whose commands.run returns a fixed result."""
|
||||
run_result = SimpleNamespace(stdout=stdout, exit_code=exit_code)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
return SimpleNamespace(commands=commands)
|
||||
|
||||
|
||||
class TestCheckSandboxSymlinkEscape:
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonical_path_within_workdir_returns_path(self):
|
||||
"""When readlink -f resolves to a path inside E2B_WORKDIR, returns it."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/src\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result == f"{E2B_WORKDIR}/src"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workdir_itself_returns_workdir(self):
|
||||
"""When readlink -f resolves to E2B_WORKDIR exactly, returns E2B_WORKDIR."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, E2B_WORKDIR)
|
||||
assert result == E2B_WORKDIR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_symlink_escape_returns_none(self):
|
||||
"""When readlink -f resolves outside E2B_WORKDIR (symlink escape), returns None."""
|
||||
sandbox = _make_sandbox(stdout="/etc\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonzero_exit_code_returns_none(self):
|
||||
"""A non-zero exit code from readlink -f returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=1)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_stdout_returns_none(self):
|
||||
"""Empty stdout from readlink (e.g. path doesn't exist yet) returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefix_collision_returns_none(self):
|
||||
"""A path prefixed with E2B_WORKDIR but not within it is rejected."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}-evil\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}-evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deeply_nested_path_within_workdir(self):
|
||||
"""Deep nested paths inside E2B_WORKDIR are allowed."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
|
||||
assert result == f"{E2B_WORKDIR}/a/b/c/d"
|
||||
|
||||
@@ -1,531 +0,0 @@
|
||||
"""End-to-end compaction flow test.
|
||||
|
||||
Simulates the full service.py compaction lifecycle using real-format
|
||||
JSONL session files — no SDK subprocess needed. Exercises:
|
||||
|
||||
1. TranscriptBuilder loads a "downloaded" transcript
|
||||
2. User query appended, assistant response streamed
|
||||
3. PreCompact hook fires → CompactionTracker.on_compact()
|
||||
4. Next message → emit_start_if_ready() yields spinner events
|
||||
5. Message after that → emit_end_if_ready() returns CompactionResult
|
||||
6. read_compacted_entries() reads the CLI session file
|
||||
7. TranscriptBuilder.replace_entries() syncs state
|
||||
8. More messages appended post-compaction
|
||||
9. to_jsonl() exports full state for upload
|
||||
10. Fresh builder loads the export — roundtrip verified
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import (
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.compaction import CompactionTracker
|
||||
from backend.copilot.sdk.transcript import (
|
||||
read_compacted_entries,
|
||||
strip_progress_entries,
|
||||
)
|
||||
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run an async coroutine synchronously."""
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic CLI session file content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Pre-compaction conversation
|
||||
USER_1 = {
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "What files are in this project?"},
|
||||
}
|
||||
ASST_1_THINKING = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-think",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
ASST_1_TOOL = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-tool",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"},
|
||||
}
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
TOOL_RESULT_1 = {
|
||||
"type": "user",
|
||||
"uuid": "tr1",
|
||||
"parentUuid": "a1-tool",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "tu1",
|
||||
"content": "file1.py\nfile2.py",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
ASST_1_TEXT = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-text",
|
||||
"parentUuid": "tr1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_bbb",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
# Progress entries (should be stripped during upload)
|
||||
PROGRESS_1 = {
|
||||
"type": "progress",
|
||||
"uuid": "prog1",
|
||||
"parentUuid": "a1-tool",
|
||||
"data": {"type": "bash_progress", "stdout": "running ls..."},
|
||||
}
|
||||
# Second user message
|
||||
USER_2 = {
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"parentUuid": "a1-text",
|
||||
"message": {"role": "user", "content": "Show me file1.py"},
|
||||
}
|
||||
ASST_2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "u2",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ccc",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file1.py content..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Compaction summary (written by CLI after context compaction) ---
|
||||
COMPACT_SUMMARY = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Summary: User asked about project files. Found file1.py and file2.py. "
|
||||
"User then asked to see file1.py."
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction assistant response
|
||||
POST_COMPACT_ASST = {
|
||||
"type": "assistant",
|
||||
"uuid": "a3",
|
||||
"parentUuid": "cs1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ddd",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction user follow-up
|
||||
USER_3 = {
|
||||
"type": "user",
|
||||
"uuid": "u3",
|
||||
"parentUuid": "a3",
|
||||
"message": {"role": "user", "content": "Now show file2.py"},
|
||||
}
|
||||
ASST_3 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a4",
|
||||
"parentUuid": "u3",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_eee",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file2.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionE2E:
|
||||
def _write_session_file(self, session_dir, entries):
|
||||
"""Write a CLI session JSONL file."""
|
||||
path = session_dir / "session.jsonl"
|
||||
path.write_text(_make_jsonl(*entries))
|
||||
return path
|
||||
|
||||
def test_full_compaction_lifecycle(self, tmp_path, monkeypatch):
|
||||
"""Simulate the complete service.py compaction flow.
|
||||
|
||||
Timeline:
|
||||
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
|
||||
2. Current turn: download → load_previous
|
||||
3. User sends "Now show file2.py" → append_user
|
||||
4. SDK starts streaming response
|
||||
5. Mid-stream: PreCompact hook fires (context too large)
|
||||
6. CLI writes compaction summary to session file
|
||||
7. Next SDK message → emit_start (spinner)
|
||||
8. Following message → emit_end (CompactionResult)
|
||||
9. read_compacted_entries reads the session file
|
||||
10. replace_entries syncs TranscriptBuilder
|
||||
11. More assistant messages appended
|
||||
12. Export → upload → next turn downloads it
|
||||
"""
|
||||
# --- Setup CLI projects directory ---
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
|
||||
previous_transcript = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(previous_transcript)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
# --- Step 3: User sends new query ---
|
||||
builder.append_user("Now show file2.py")
|
||||
assert builder.entry_count == 8
|
||||
|
||||
# --- Step 4: SDK starts streaming ---
|
||||
builder.append_assistant(
|
||||
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
assert builder.entry_count == 9
|
||||
|
||||
# --- Step 5-6: PreCompact fires, CLI writes session file ---
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
USER_3,
|
||||
ASST_3,
|
||||
],
|
||||
)
|
||||
|
||||
# --- Step 7: CompactionTracker receives PreCompact hook ---
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test-user")
|
||||
tracker.on_compact(str(session_file))
|
||||
|
||||
# --- Step 8: Next SDK message arrives → emit_start ---
|
||||
start_events = tracker.emit_start_if_ready()
|
||||
assert len(start_events) == 3
|
||||
assert isinstance(start_events[0], StreamStartStep)
|
||||
assert isinstance(start_events[1], StreamToolInputStart)
|
||||
assert isinstance(start_events[2], StreamToolInputAvailable)
|
||||
|
||||
# Verify tool_call_id is set
|
||||
tool_call_id = start_events[1].toolCallId
|
||||
assert tool_call_id.startswith("compaction-")
|
||||
|
||||
# --- Step 9: Following message → emit_end ---
|
||||
result = _run(tracker.emit_end_if_ready(session))
|
||||
assert result.just_ended is True
|
||||
assert result.transcript_path == str(session_file)
|
||||
assert len(result.events) == 2
|
||||
assert isinstance(result.events[0], StreamToolOutputAvailable)
|
||||
assert isinstance(result.events[1], StreamFinishStep)
|
||||
# Verify same tool_call_id
|
||||
assert result.events[0].toolCallId == tool_call_id
|
||||
|
||||
# Session should have compaction messages persisted
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[0].role == "assistant"
|
||||
assert session.messages[1].role == "tool"
|
||||
|
||||
# --- Step 10: read_compacted_entries + replace_entries ---
|
||||
compacted = read_compacted_entries(str(session_file))
|
||||
assert compacted is not None
|
||||
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
|
||||
assert len(compacted) == 4
|
||||
assert compacted[0]["uuid"] == "cs1"
|
||||
assert compacted[0]["isCompactSummary"] is True
|
||||
|
||||
# Replace builder state with compacted entries
|
||||
old_count = builder.entry_count
|
||||
builder.replace_entries(compacted)
|
||||
assert builder.entry_count == 4 # Only compacted entries
|
||||
assert builder.entry_count < old_count # Compaction reduced entries
|
||||
|
||||
# --- Step 11: More assistant messages after compaction ---
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
assert builder.entry_count == 5
|
||||
|
||||
# --- Step 12: Export for upload ---
|
||||
output = builder.to_jsonl()
|
||||
assert output # Not empty
|
||||
output_entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert len(output_entries) == 5
|
||||
|
||||
# Verify structure:
|
||||
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
|
||||
assert output_entries[0]["type"] == "summary"
|
||||
assert output_entries[0].get("isCompactSummary") is True
|
||||
assert output_entries[0]["uuid"] == "cs1"
|
||||
assert output_entries[1]["uuid"] == "a3"
|
||||
assert output_entries[2]["uuid"] == "u3"
|
||||
assert output_entries[3]["uuid"] == "a4"
|
||||
assert output_entries[4]["type"] == "assistant"
|
||||
|
||||
# Verify parent chain is intact
|
||||
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
|
||||
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
|
||||
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
|
||||
assert output_entries[4]["parentUuid"] == "a4" # new → a4
|
||||
|
||||
# --- Step 13: Roundtrip — next turn loads this export ---
|
||||
builder2 = TranscriptBuilder()
|
||||
builder2.load_previous(output)
|
||||
assert builder2.entry_count == 5
|
||||
|
||||
# isCompactSummary survives roundtrip
|
||||
output2 = builder2.to_jsonl()
|
||||
first_entry = json.loads(output2.strip().split("\n")[0])
|
||||
assert first_entry.get("isCompactSummary") is True
|
||||
|
||||
# Can append more messages
|
||||
builder2.append_user("What about file3.py?")
|
||||
assert builder2.entry_count == 6
|
||||
final_output = builder2.to_jsonl()
|
||||
last_entry = json.loads(final_output.strip().split("\n")[-1])
|
||||
assert last_entry["type"] == "user"
|
||||
# Parented to the last entry from previous turn
|
||||
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
|
||||
|
||||
def test_double_compaction_within_session(self, tmp_path, monkeypatch):
|
||||
"""Two compactions in the same session (across reset_for_query)."""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test")
|
||||
builder = TranscriptBuilder()
|
||||
|
||||
# --- First query with compaction ---
|
||||
builder.append_user("first question")
|
||||
builder.append_assistant([{"type": "text", "text": "first answer"}])
|
||||
|
||||
# Write session file for first compaction
|
||||
first_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-first",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "First compaction summary"},
|
||||
}
|
||||
first_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-first",
|
||||
"parentUuid": "cs-first",
|
||||
"message": {"role": "assistant", "content": "first post-compact"},
|
||||
}
|
||||
file1 = session_dir / "session1.jsonl"
|
||||
file1.write_text(_make_jsonl(first_summary, first_post))
|
||||
|
||||
tracker.on_compact(str(file1))
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = _run(tracker.emit_end_if_ready(session))
|
||||
assert result1.just_ended is True
|
||||
|
||||
compacted1 = read_compacted_entries(str(file1))
|
||||
assert compacted1 is not None
|
||||
builder.replace_entries(compacted1)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# --- Reset for second query ---
|
||||
tracker.reset_for_query()
|
||||
|
||||
# --- Second query with compaction ---
|
||||
builder.append_user("second question")
|
||||
builder.append_assistant([{"type": "text", "text": "second answer"}])
|
||||
|
||||
second_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-second",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Second compaction summary"},
|
||||
}
|
||||
second_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-second",
|
||||
"parentUuid": "cs-second",
|
||||
"message": {"role": "assistant", "content": "second post-compact"},
|
||||
}
|
||||
file2 = session_dir / "session2.jsonl"
|
||||
file2.write_text(_make_jsonl(second_summary, second_post))
|
||||
|
||||
tracker.on_compact(str(file2))
|
||||
tracker.emit_start_if_ready()
|
||||
result2 = _run(tracker.emit_end_if_ready(session))
|
||||
assert result2.just_ended is True
|
||||
|
||||
compacted2 = read_compacted_entries(str(file2))
|
||||
assert compacted2 is not None
|
||||
builder.replace_entries(compacted2)
|
||||
assert builder.entry_count == 2 # Only second compaction entries
|
||||
|
||||
# Export and verify
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert entries[0]["uuid"] == "cs-second"
|
||||
assert entries[0].get("isCompactSummary") is True
|
||||
|
||||
def test_strip_progress_then_load_then_compact_roundtrip(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""Full pipeline: strip → load → compact → replace → export → reload.
|
||||
|
||||
This tests the exact sequence that happens across two turns:
|
||||
Turn 1: SDK produces transcript with progress entries
|
||||
Upload: strip_progress_entries removes progress, upload to cloud
|
||||
Turn 2: Download → load_previous → compaction fires → replace → export
|
||||
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
|
||||
"""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# --- Turn 1: SDK produces raw transcript ---
|
||||
raw_content = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
|
||||
# Strip progress for upload
|
||||
stripped = strip_progress_entries(raw_content)
|
||||
stripped_entries = [
|
||||
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
|
||||
]
|
||||
# Progress should be gone
|
||||
assert not any(e.get("type") == "progress" for e in stripped_entries)
|
||||
assert len(stripped_entries) == 7 # 8 - 1 progress
|
||||
|
||||
# --- Turn 2: Download stripped, load, compaction happens ---
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(stripped)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
builder.append_user("Now show file2.py")
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Reading file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
|
||||
# CLI writes session file with compaction
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
],
|
||||
)
|
||||
|
||||
compacted = read_compacted_entries(str(session_file))
|
||||
assert compacted is not None
|
||||
builder.replace_entries(compacted)
|
||||
|
||||
# Append post-compaction message
|
||||
builder.append_user("Thanks!")
|
||||
output = builder.to_jsonl()
|
||||
|
||||
# --- Turn 3: Fresh load of Turn 2 export ---
|
||||
builder3 = TranscriptBuilder()
|
||||
builder3.load_previous(output)
|
||||
# Should have: compact_summary + post_compact_asst + "Thanks!"
|
||||
assert builder3.entry_count == 3
|
||||
|
||||
# Compact summary survived the full pipeline
|
||||
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
|
||||
assert first.get("isCompactSummary") is True
|
||||
assert first["type"] == "summary"
|
||||
@@ -41,20 +41,12 @@ from typing import Any
|
||||
from backend.copilot.context import (
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
get_workspace_manager,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.workspace_files import get_manager
|
||||
from backend.util.file import parse_workspace_uri
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
MIME_TO_FORMAT,
|
||||
PARSE_EXCEPTIONS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class FileRefExpansionError(Exception):
|
||||
@@ -82,8 +74,6 @@ _FILE_REF_RE = re.compile(
|
||||
_MAX_EXPAND_CHARS = 200_000
|
||||
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||
# Maximum raw byte size for bare ref structured parsing (10 MB).
|
||||
_MAX_BARE_REF_BYTES = 10_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,11 +83,6 @@ class FileRef:
|
||||
end_line: int | None # 1-indexed, inclusive
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_file_ref(text: str) -> FileRef | None:
|
||||
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||
|
||||
@@ -119,6 +104,17 @@ def parse_file_ref(text: str) -> FileRef | None:
|
||||
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else len(lines)
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
@@ -134,47 +130,27 @@ async def read_file_bytes(
|
||||
if plain.startswith("workspace://"):
|
||||
if not user_id:
|
||||
raise ValueError("workspace:// file references require authentication")
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
ws = parse_workspace_uri(plain)
|
||||
try:
|
||||
data = await (
|
||||
return await (
|
||||
manager.read_file(ws.file_ref)
|
||||
if ws.is_path
|
||||
else manager.read_file_by_id(ws.file_ref)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except (PermissionError, OSError) as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
except (AttributeError, TypeError, RuntimeError) as exc:
|
||||
# AttributeError/TypeError: workspace manager returned an
|
||||
# unexpected type or interface; RuntimeError: async runtime issues.
|
||||
logger.warning("Unexpected error reading %s: %s", plain, exc)
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
# NOTE: Workspace API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
# Read with a one-byte overshoot to detect files that exceed the limit
|
||||
# without a separate os.path.getsize call (avoids TOCTOU race).
|
||||
with open(resolved, "rb") as fh:
|
||||
data = fh.read(_MAX_BARE_REF_BYTES + 1)
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
|
||||
f"limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
return fh.read()
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except OSError as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
@@ -186,33 +162,9 @@ async def read_file_bytes(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
) from exc
|
||||
try:
|
||||
data = bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
return bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except Exception as exc:
|
||||
# E2B SDK raises SandboxException subclasses (NotFoundException,
|
||||
# TimeoutException, NotEnoughSpaceException, etc.) which don't
|
||||
# inherit from standard exceptions. Import lazily to avoid a
|
||||
# hard dependency on e2b at module level.
|
||||
try:
|
||||
from e2b.exceptions import SandboxException # noqa: PLC0415
|
||||
|
||||
if isinstance(exc, SandboxException):
|
||||
raise ValueError(
|
||||
f"Failed to read from sandbox: {plain}: {exc}"
|
||||
) from exc
|
||||
except ImportError:
|
||||
pass
|
||||
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
|
||||
# so they surface as real bugs rather than being silently masked.
|
||||
raise
|
||||
# NOTE: E2B sandbox API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
@@ -226,13 +178,15 @@ async def resolve_file_ref(
|
||||
) -> str:
|
||||
"""Resolve a :class:`FileRef` to its text content."""
|
||||
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
|
||||
return _apply_line_range(
|
||||
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
|
||||
)
|
||||
|
||||
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
session: "ChatSession",
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
@@ -278,9 +232,6 @@ async def expand_file_refs_in_string(
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
# remaining == 0 means the budget was exactly exhausted by the
|
||||
# previous ref. The elif below (len > remaining) won't catch
|
||||
# this since 0 > 0 is false, so we need the <= 0 check.
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
@@ -301,31 +252,13 @@ async def expand_file_refs_in_string(
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
*,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
session: "ChatSession",
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||
|
||||
String values are expanded in-place. Nested dicts and lists are
|
||||
traversed. Non-string scalars are returned unchanged.
|
||||
|
||||
**Bare references** (the entire argument value is a single
|
||||
``@@agptfile:...`` token with no surrounding text) are resolved and then
|
||||
parsed according to the file's extension or MIME type. See
|
||||
:mod:`backend.util.file_content_parser` for the full list of supported
|
||||
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
|
||||
|
||||
When *input_schema* is provided and the target property has
|
||||
``"type": "string"``, structured parsing is skipped — the raw file content
|
||||
is returned as a plain string so blocks receive the original text.
|
||||
|
||||
If the format is unrecognised or parsing fails, the content is returned as
|
||||
a plain string (the fallback).
|
||||
|
||||
**Embedded references** (``@@agptfile:`` mixed with other text) always
|
||||
produce a plain string — structured parsing only applies to bare refs.
|
||||
|
||||
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||
so the tool is *not* executed with an error string as its input. The
|
||||
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||
@@ -334,382 +267,15 @@ async def expand_file_refs_in_args(
|
||||
if not args:
|
||||
return args
|
||||
|
||||
properties = (input_schema or {}).get("properties", {})
|
||||
|
||||
async def _expand(
|
||||
value: Any,
|
||||
*,
|
||||
prop_schema: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Recursively expand a single argument value.
|
||||
|
||||
Strings are checked for ``@@agptfile:`` references and expanded
|
||||
(bare refs get structured parsing; embedded refs get inline
|
||||
substitution). Dicts and lists are traversed recursively,
|
||||
threading the corresponding sub-schema from *prop_schema* so
|
||||
that nested fields also receive correct type-aware expansion.
|
||||
Non-string scalars pass through unchanged.
|
||||
"""
|
||||
async def _expand(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
ref = parse_file_ref(value)
|
||||
if ref is not None:
|
||||
# MediaFileType fields: return the raw URI immediately —
|
||||
# no file reading, no format inference, no content parsing.
|
||||
if _is_media_file_field(prop_schema):
|
||||
return ref.uri
|
||||
|
||||
fmt = infer_format_from_uri(ref.uri)
|
||||
# Workspace URIs by ID (workspace://abc123) have no extension.
|
||||
# When the MIME fragment is also missing, fall back to the
|
||||
# workspace file manager's metadata for format detection.
|
||||
if fmt is None and ref.uri.startswith("workspace://"):
|
||||
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
|
||||
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
|
||||
|
||||
# Not a bare ref — do normal inline expansion.
|
||||
return await expand_file_refs_in_string(
|
||||
value, user_id, session, raise_on_error=True
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
# When the schema says this is an object but doesn't define
|
||||
# inner properties, skip expansion — the caller (e.g.
|
||||
# RunBlockTool) will expand with the actual nested schema.
|
||||
if (
|
||||
prop_schema is not None
|
||||
and prop_schema.get("type") == "object"
|
||||
and "properties" not in prop_schema
|
||||
):
|
||||
return value
|
||||
nested_props = (prop_schema or {}).get("properties", {})
|
||||
return {
|
||||
k: await _expand(v, prop_schema=nested_props.get(k))
|
||||
for k, v in value.items()
|
||||
}
|
||||
return {k: await _expand(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
items_schema = (prop_schema or {}).get("items")
|
||||
return [await _expand(item, prop_schema=items_schema) for item in value]
|
||||
return [await _expand(item) for item in value]
|
||||
return value
|
||||
|
||||
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers (used by the public functions above)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive).
|
||||
|
||||
When the requested range extends beyond the file, a note is appended
|
||||
so the LLM knows it received the entire remaining content.
|
||||
"""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
total = len(lines)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else total
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
result = "".join(selected)
|
||||
if end is not None and end > total:
|
||||
result += f"\n[Note: file has only {total} lines]\n"
|
||||
return result
|
||||
|
||||
|
||||
def _to_str(content: str | bytes) -> str:
|
||||
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return content.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _check_content_size(content: str | bytes) -> None:
|
||||
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
|
||||
|
||||
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
|
||||
(``_expand_bare_ref``) can unify all resolution errors into a single
|
||||
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
|
||||
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
|
||||
|
||||
For ``bytes``, the length is the byte count directly. For ``str``,
|
||||
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
|
||||
mean the byte size can be up to 4x the character count.
|
||||
"""
|
||||
if isinstance(content, bytes):
|
||||
size = len(content)
|
||||
else:
|
||||
char_len = len(content)
|
||||
# Fast lower bound: UTF-8 byte count >= char count.
|
||||
# If char count already exceeds the limit, reject immediately
|
||||
# without allocating an encoded copy.
|
||||
if char_len > _MAX_BARE_REF_BYTES:
|
||||
size = char_len # real byte size is even larger
|
||||
# Fast upper bound: each char is at most 4 UTF-8 bytes.
|
||||
# If worst-case is still under the limit, skip encoding entirely.
|
||||
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
|
||||
return
|
||||
else:
|
||||
# Edge case: char count is under limit but multibyte chars
|
||||
# might push byte count over. Encode to get exact size.
|
||||
size = len(content.encode("utf-8"))
|
||||
if size > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large for structured parsing "
|
||||
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
|
||||
|
||||
async def _infer_format_from_workspace(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str | None:
|
||||
"""Look up workspace file metadata to infer the format.
|
||||
|
||||
Workspace URIs by ID (``workspace://abc123``) have no file extension.
|
||||
When the MIME fragment is also absent, we query the workspace file
|
||||
manager for the file's stored MIME type and original filename.
|
||||
"""
|
||||
if not user_id:
|
||||
return None
|
||||
try:
|
||||
ws = parse_workspace_uri(uri)
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
info = await (
|
||||
manager.get_file_info(ws.file_ref)
|
||||
if not ws.is_path
|
||||
else manager.get_file_info_by_path(ws.file_ref)
|
||||
)
|
||||
if info is None:
|
||||
return None
|
||||
# Try MIME type first, then filename extension.
|
||||
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
|
||||
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
|
||||
except (
|
||||
ValueError,
|
||||
FileNotFoundError,
|
||||
OSError,
|
||||
PermissionError,
|
||||
AttributeError,
|
||||
TypeError,
|
||||
):
|
||||
# Expected failures: bad URI, missing file, permission denied, or
|
||||
# workspace manager returning unexpected types. Propagate anything
|
||||
# else (e.g. programming errors) so they don't get silently swallowed.
|
||||
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
|
||||
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
|
||||
if prop_schema is None:
|
||||
return False
|
||||
return (
|
||||
prop_schema.get("type") == "string"
|
||||
and prop_schema.get("format") == MediaFileType.string_format
|
||||
)
|
||||
|
||||
|
||||
async def _expand_bare_ref(
|
||||
ref: FileRef,
|
||||
fmt: str | None,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
prop_schema: dict[str, Any] | None,
|
||||
) -> Any:
|
||||
"""Resolve and parse a bare ``@@agptfile:`` reference.
|
||||
|
||||
This is the structured-parsing path: the file is read, optionally parsed
|
||||
according to *fmt*, and adapted to the target *prop_schema*.
|
||||
|
||||
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
|
||||
|
||||
Note: MediaFileType fields (format: "file") are handled earlier in
|
||||
``_expand`` to avoid unnecessary format inference and file I/O.
|
||||
"""
|
||||
try:
|
||||
if fmt is not None and fmt in BINARY_FORMATS:
|
||||
# Binary formats need raw bytes, not UTF-8 text.
|
||||
# Line ranges are meaningless for binary formats (parquet/xlsx)
|
||||
# — ignore them and parse full bytes. Warn so the caller/model
|
||||
# knows the range was silently dropped.
|
||||
if ref.start_line is not None or ref.end_line is not None:
|
||||
logger.warning(
|
||||
"Line range [%s-%s] ignored for binary format %s (%s); "
|
||||
"binary formats are always parsed in full.",
|
||||
ref.start_line,
|
||||
ref.end_line,
|
||||
fmt,
|
||||
ref.uri,
|
||||
)
|
||||
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
|
||||
else:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# For known formats this rejects files >10 MB before parsing.
|
||||
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
|
||||
# but this check still guards the parsing path which has no char limit.
|
||||
# _check_content_size raises ValueError, which we unify here just like
|
||||
# resolution errors above.
|
||||
try:
|
||||
_check_content_size(content)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# When the schema declares this parameter as "string",
|
||||
# return raw file content — don't parse into a structured
|
||||
# type that would need json.dumps() serialisation.
|
||||
expect_string = (prop_schema or {}).get("type") == "string"
|
||||
if expect_string:
|
||||
if isinstance(content, bytes):
|
||||
raise FileRefExpansionError(
|
||||
f"Cannot use {fmt} file as text input: "
|
||||
f"binary formats (parquet, xlsx) must be passed "
|
||||
f"to a block that accepts structured data (list/object), "
|
||||
f"not a string-typed parameter."
|
||||
)
|
||||
return content
|
||||
|
||||
if fmt is not None:
|
||||
# Use strict mode for binary formats so we surface the
|
||||
# actual error (e.g. missing pyarrow/openpyxl, corrupt
|
||||
# file) instead of silently returning garbled bytes.
|
||||
strict = fmt in BINARY_FORMATS
|
||||
try:
|
||||
parsed = parse_file_content(content, fmt, strict=strict)
|
||||
except PARSE_EXCEPTIONS as exc:
|
||||
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
|
||||
# Normalize bytes fallback to str so tools never
|
||||
# receive raw bytes when parsing fails.
|
||||
if isinstance(parsed, bytes):
|
||||
parsed = _to_str(parsed)
|
||||
return _adapt_to_schema(parsed, prop_schema)
|
||||
|
||||
# Unknown format — return as plain string, but apply
|
||||
# the same per-ref character limit used by inline refs
|
||||
# to prevent injecting unexpectedly large content.
|
||||
text = _to_str(content)
|
||||
if len(text) > _MAX_EXPAND_CHARS:
|
||||
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
return text
|
||||
|
||||
|
||||
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
|
||||
"""Adapt a parsed file value to better fit the target schema type.
|
||||
|
||||
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
|
||||
that doesn't match the block's expected type, this function converts it to
|
||||
a more useful representation instead of relying on pydantic's generic
|
||||
coercion (which can produce awkward results like flattened dicts → lists).
|
||||
|
||||
Returns *parsed* unchanged when no adaptation is needed.
|
||||
"""
|
||||
if prop_schema is None:
|
||||
return parsed
|
||||
|
||||
target_type = prop_schema.get("type")
|
||||
|
||||
# Dict → array: delegate to helper.
|
||||
if isinstance(parsed, dict) and target_type == "array":
|
||||
return _adapt_dict_to_array(parsed, prop_schema)
|
||||
|
||||
# List → object: delegate to helper (raises for non-tabular lists).
|
||||
if isinstance(parsed, list) and target_type == "object":
|
||||
return _adapt_list_to_object(parsed)
|
||||
|
||||
# Tabular list → Any (no type): convert to list of dicts.
|
||||
# Blocks like FindInDictionaryBlock have `input: Any` which produces
|
||||
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
|
||||
# for key lookup, but [{col: val}, ...] works with FindInDict's
|
||||
# list-of-dicts branch (line 195-199 in data_manipulation.py).
|
||||
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
|
||||
return _tabular_to_list_of_dicts(parsed)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
|
||||
"""Adapt a parsed dict to an array-typed field.
|
||||
|
||||
Extracts list-valued entries when the target item type is ``array``,
|
||||
passes through unchanged when item type is ``string`` (lets pydantic error),
|
||||
or wraps in ``[parsed]`` as a fallback.
|
||||
"""
|
||||
items_type = (prop_schema.get("items") or {}).get("type")
|
||||
if items_type == "array":
|
||||
# Target is List[List[Any]] — extract list-typed values from the
|
||||
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
|
||||
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
|
||||
list_values = [v for v in parsed.values() if isinstance(v, list)]
|
||||
if list_values:
|
||||
return list_values
|
||||
if items_type == "string":
|
||||
# Target is List[str] — wrapping a dict would give [dict]
|
||||
# which can't coerce to strings. Return unchanged and let
|
||||
# pydantic surface a clear validation error.
|
||||
return parsed
|
||||
# Fallback: wrap in a single-element list so the block gets [dict]
|
||||
# instead of pydantic flattening keys/values into a flat list.
|
||||
return [parsed]
|
||||
|
||||
|
||||
def _adapt_list_to_object(parsed: list) -> Any:
|
||||
"""Adapt a parsed list to an object-typed field.
|
||||
|
||||
Converts tabular lists to column-dicts; raises for non-tabular lists.
|
||||
"""
|
||||
if _is_tabular(parsed):
|
||||
return _tabular_to_column_dict(parsed)
|
||||
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
|
||||
# be meaningfully coerced to an object. Raise explicitly so callers
|
||||
# get a clear error rather than pydantic silently wrapping the list.
|
||||
raise FileRefExpansionError(
|
||||
"Cannot adapt a non-tabular list to an object-typed field. "
|
||||
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
|
||||
)
|
||||
|
||||
|
||||
def _is_tabular(parsed: Any) -> bool:
|
||||
"""Check if parsed data is in tabular format: [[header], [row1], ...].
|
||||
|
||||
Uses isinstance checks because this is a structural type guard on
|
||||
opaque parser output (Any), not duck typing. A Protocol wouldn't
|
||||
help here — we need to verify exact list-of-lists shape.
|
||||
"""
|
||||
if not isinstance(parsed, list) or len(parsed) < 2:
|
||||
return False
|
||||
header = parsed[0]
|
||||
if not isinstance(header, list) or not header:
|
||||
return False
|
||||
if not all(isinstance(h, str) for h in header):
|
||||
return False
|
||||
return all(isinstance(row, list) for row in parsed[1:])
|
||||
|
||||
|
||||
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
|
||||
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values.
|
||||
Extra values beyond the header length are silently dropped.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return [
|
||||
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
|
||||
for row in parsed[1:]
|
||||
]
|
||||
|
||||
|
||||
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
|
||||
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values,
|
||||
ensuring all columns have equal length.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return {
|
||||
col: [row[i] if i < len(row) else None for row in parsed[1:]]
|
||||
for i, col in enumerate(header)
|
||||
}
|
||||
return {k: await _expand(v) for k, v in args.items()}
|
||||
|
||||
@@ -175,199 +175,6 @@ async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||
assert result["count"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — bare ref structured parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_json_returns_parsed_dict():
|
||||
"""Bare ref to a .json file returns parsed dict, not raw string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value", "count": 42}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == {"key": "value", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_csv_returns_parsed_table():
|
||||
"""Bare ref to a .csv file returns list[list[str]] table."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
csv_file = os.path.join(sdk_cwd, "data.csv")
|
||||
with open(csv_file, "w") as f:
|
||||
f.write("Name,Score\nAlice,90\nBob,85")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"input": f"@@agptfile:{csv_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["input"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
["Bob", "85"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_unknown_extension_returns_string():
|
||||
"""Bare ref to a file with unknown extension returns plain string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
txt_file = os.path.join(sdk_cwd, "readme.txt")
|
||||
with open(txt_file, "w") as f:
|
||||
f.write("plain text content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{txt_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "plain text content"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_invalid_json_falls_back_to_string():
|
||||
"""Bare ref to a .json file with invalid JSON falls back to string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "bad.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write("not valid json {{{")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "not valid json {{{"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedded_ref_always_returns_string_even_for_json():
|
||||
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value"}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"prefix @@agptfile:{json_file} suffix"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert isinstance(result["data"], str)
|
||||
assert result["data"].startswith("prefix ")
|
||||
assert result["data"].endswith(" suffix")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_yaml_returns_parsed_dict():
|
||||
"""Bare ref to a .yaml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
yaml_file = os.path.join(sdk_cwd, "config.yaml")
|
||||
with open(yaml_file, "w") as f:
|
||||
f.write("name: test\ncount: 42\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{yaml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_binary_with_line_range_ignores_range():
|
||||
"""Bare ref to a binary file (.parquet) with line range parses the full file.
|
||||
|
||||
Binary formats (parquet, xlsx) ignore line ranges — the full content is
|
||||
parsed and the range is silently dropped with a log warning.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pytest.skip("pandas not installed")
|
||||
try:
|
||||
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
except ImportError:
|
||||
pytest.skip("pyarrow not installed")
|
||||
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
parquet_file = os.path.join(sdk_cwd, "data.parquet")
|
||||
import io as _io
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = _io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
with open(parquet_file, "wb") as f:
|
||||
f.write(buf.getvalue())
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
# Line range [1-2] should be silently ignored for binary formats.
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{parquet_file}[1-2]"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
# Full file is returned despite the line range.
|
||||
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_toml_returns_parsed_dict():
|
||||
"""Bare ref to a .toml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
toml_file = os.path.join(sdk_cwd, "config.toml")
|
||||
with open(toml_file, "w") as f:
|
||||
f.write('name = "test"\ncount = 42\n')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{toml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — extended to accept workspace:// and local paths
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -412,7 +219,7 @@ async def test_read_file_handler_workspace_uri():
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
@@ -469,7 +276,7 @@ async def test_read_file_bytes_workspace_virtual_path():
|
||||
mock_manager.read_file.return_value = b"virtual path content"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,40 +20,9 @@ Use these URLs directly without asking the user:
|
||||
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
|
||||
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
|
||||
|
||||
For other services, search the MCP registry API:
|
||||
```http
|
||||
GET https://registry.modelcontextprotocol.io/v0/servers?q=<search_term>
|
||||
```
|
||||
Each result includes a `remotes` array with the exact server URL to use.
|
||||
|
||||
### Important: Check blocks first
|
||||
|
||||
Before using `run_mcp_tool`, always check if the platform already has blocks for the service
|
||||
using `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Google Docs,
|
||||
Google Calendar, Gmail, etc.) that work without MCP setup.
|
||||
|
||||
Only use `run_mcp_tool` when:
|
||||
- The service is in the known hosted MCP servers list above, OR
|
||||
- You searched `find_block` first and found no matching blocks
|
||||
|
||||
**Never guess or construct MCP server URLs.** Only use URLs from the known servers list above
|
||||
or from the `remotes[].url` field in MCP registry search results.
|
||||
For other services, search the MCP registry at https://registry.modelcontextprotocol.io/.
|
||||
|
||||
### Authentication
|
||||
|
||||
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
|
||||
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
|
||||
|
||||
### Communication style
|
||||
|
||||
Avoid technical jargon like "MCP server", "OAuth", or "credentials" when talking to the user.
|
||||
Use plain, friendly language instead:
|
||||
|
||||
| Instead of… | Say… |
|
||||
|---|---|
|
||||
| "Let me connect to Sentry's MCP server and discover what tools are available." | "I can connect to Sentry and help identify important issues." |
|
||||
| "Let me connect to Sentry's MCP server now." | "Next, I'll connect to Sentry." |
|
||||
| "The MCP server at mcp.sentry.dev requires authentication. Please connect your credentials to continue." | "To continue, sign in to Sentry and approve access." |
|
||||
| "Sentry's MCP server needs OAuth authentication. You should see a prompt to connect your Sentry account…" | "You should see a prompt to sign in to Sentry. Once connected, I can help surface critical issues right away." |
|
||||
|
||||
Use **"connect to [Service]"** or **"sign in to [Service]"** — never "MCP server", "OAuth", or "credentials".
|
||||
|
||||
@@ -42,7 +42,7 @@ def _validate_workspace_path(
|
||||
Delegates to :func:`is_allowed_local_path` which permits:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The current session's tool-results directory
|
||||
(``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/``)
|
||||
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
@@ -127,7 +127,7 @@ def create_security_hooks(
|
||||
user_id: str | None,
|
||||
sdk_cwd: str | None = None,
|
||||
max_subtasks: int = 3,
|
||||
on_compact: Callable[[str], None] | None = None,
|
||||
on_compact: Callable[[], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
@@ -142,7 +142,6 @@ def create_security_hooks(
|
||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
|
||||
on_compact: Callback invoked when SDK starts compacting context.
|
||||
Receives the transcript_path from the hook input.
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
@@ -302,25 +301,11 @@ def create_security_hooks(
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
# Sanitize untrusted input: strip control chars for logging AND
|
||||
# for the value passed downstream. read_compacted_entries()
|
||||
# validates against _projects_base() as defence-in-depth, but
|
||||
# sanitizing here prevents log injection and rejects obviously
|
||||
# malformed paths early.
|
||||
transcript_path = (
|
||||
str(input_data.get("transcript_path", ""))
|
||||
.replace("\n", "")
|
||||
.replace("\r", "")
|
||||
)
|
||||
logger.info(
|
||||
"[SDK] Context compaction triggered: %s, user=%s, "
|
||||
"transcript_path=%s",
|
||||
trigger,
|
||||
user_id,
|
||||
transcript_path,
|
||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||
)
|
||||
if on_compact is not None:
|
||||
on_compact(transcript_path)
|
||||
on_compact()
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
hooks: dict[str, Any] = {
|
||||
|
||||
@@ -122,7 +122,7 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-results/12345.txt"
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
|
||||
@@ -10,7 +10,6 @@ import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
@@ -30,7 +29,6 @@ from langfuse import propagate_attributes
|
||||
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -39,7 +37,6 @@ from backend.util.settings import Settings
|
||||
|
||||
from ..config import ChatConfig
|
||||
from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
|
||||
from ..context import encode_cwd_for_cli
|
||||
from ..model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
@@ -65,6 +62,7 @@ from ..service import (
|
||||
)
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tools.workspace_files import get_manager
|
||||
from ..tracking import track_user_message
|
||||
from .compaction import CompactionTracker, filter_compaction_messages
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
@@ -77,9 +75,8 @@ from .tool_adapter import (
|
||||
wait_for_stash,
|
||||
)
|
||||
from .transcript import (
|
||||
cleanup_stale_project_dirs,
|
||||
cleanup_cli_project_dir,
|
||||
download_transcript,
|
||||
read_compacted_entries,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -145,9 +142,6 @@ _background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
||||
|
||||
_last_sweep_time: float = 0.0
|
||||
_SWEEP_INTERVAL_SECONDS = 300 # 5 minutes
|
||||
|
||||
# Heartbeat interval — keep SSE alive through proxies/LBs during tool execution.
|
||||
# IMPORTANT: Must be less than frontend timeout (12s in useCopilotPage.ts)
|
||||
_HEARTBEAT_INTERVAL = 10.0 # seconds
|
||||
@@ -286,34 +280,31 @@ def _make_sdk_cwd(session_id: str) -> str:
|
||||
return cwd
|
||||
|
||||
|
||||
async def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
"""Remove SDK session artifacts for a specific working directory.
|
||||
|
||||
Cleans up the ephemeral working directory ``/tmp/copilot-<session>/``.
|
||||
|
||||
Also sweeps stale CLI project directories (older than 12 h) to prevent
|
||||
unbounded disk growth. The sweep is best-effort, rate-limited to once
|
||||
every 5 minutes, and capped at 50 directories per sweep.
|
||||
Cleans up:
|
||||
- ``~/.claude/projects/<encoded-cwd>/`` — CLI session transcripts and
|
||||
tool-result files. Each SDK turn uses a unique cwd, so this directory
|
||||
is safe to remove entirely.
|
||||
- ``/tmp/copilot-<session>/`` — the ephemeral working directory.
|
||||
|
||||
Security: *cwd* MUST be created by ``_make_sdk_cwd()`` which sanitizes
|
||||
the session_id.
|
||||
"""
|
||||
normalized = os.path.normpath(cwd)
|
||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
||||
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
|
||||
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
||||
return
|
||||
|
||||
await asyncio.to_thread(shutil.rmtree, normalized, True)
|
||||
# Clean the CLI's project directory (transcripts + tool-results).
|
||||
cleanup_cli_project_dir(cwd)
|
||||
|
||||
# Best-effort sweep of old project dirs to prevent disk leak.
|
||||
# Pass the encoded cwd so only this session's project directory is swept,
|
||||
# which is safe in multi-tenant environments.
|
||||
global _last_sweep_time
|
||||
now = time.time()
|
||||
if now - _last_sweep_time >= _SWEEP_INTERVAL_SECONDS:
|
||||
_last_sweep_time = now
|
||||
encoded = encode_cwd_for_cli(normalized)
|
||||
await asyncio.to_thread(cleanup_stale_project_dirs, encoded)
|
||||
# Clean up the temp cwd directory itself.
|
||||
try:
|
||||
shutil.rmtree(normalized, ignore_errors=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
@@ -573,7 +564,7 @@ async def _prepare_file_attachments(
|
||||
return empty
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create workspace manager for file attachments",
|
||||
@@ -805,7 +796,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
except Exception as transcript_err:
|
||||
logger.warning(
|
||||
"%s Transcript download failed, continuing without --resume: %s",
|
||||
"%s Transcript download failed, continuing without " "--resume: %s",
|
||||
log_prefix,
|
||||
transcript_err,
|
||||
)
|
||||
@@ -828,7 +819,7 @@ async def stream_chat_completion_sdk(
|
||||
is_valid = validate_transcript(dl.content)
|
||||
dl_lines = dl.content.strip().split("\n") if dl.content else []
|
||||
logger.info(
|
||||
"%s Downloaded transcript: %dB, %d lines, msg_count=%d, valid=%s",
|
||||
"%s Downloaded transcript: %dB, %d lines, " "msg_count=%d, valid=%s",
|
||||
log_prefix,
|
||||
len(dl.content),
|
||||
len(dl_lines),
|
||||
@@ -1054,7 +1045,6 @@ async def stream_chat_completion_sdk(
|
||||
exc_info=True,
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
|
||||
yield StreamError(
|
||||
errorText=f"SDK stream error: {stream_err}",
|
||||
code="sdk_stream_error",
|
||||
@@ -1062,7 +1052,8 @@ async def stream_chat_completion_sdk(
|
||||
break
|
||||
|
||||
logger.info(
|
||||
"%s Received: %s %s (unresolved=%d, current=%d, resolved=%d)",
|
||||
"%s Received: %s %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
log_prefix,
|
||||
type(sdk_msg).__name__,
|
||||
getattr(sdk_msg, "subtype", ""),
|
||||
@@ -1107,14 +1098,7 @@ async def stream_chat_completion_sdk(
|
||||
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
|
||||
and not is_parallel_continuation
|
||||
):
|
||||
# 2.0 s timeout: the original 0.5 s caused frequent
|
||||
# timeouts under load (parallel tool calls, large
|
||||
# outputs). 2.0 s gives margin while still failing
|
||||
# fast when the hook genuinely will not fire.
|
||||
if await wait_for_stash(timeout=2.0):
|
||||
# Yield once so any callbacks scheduled by the
|
||||
# stash signal can propagate before we process
|
||||
# the next SDK message.
|
||||
if await wait_for_stash(timeout=0.5):
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -1145,26 +1129,9 @@ async def stream_chat_completion_sdk(
|
||||
sdk_msg.result or "(no error message provided)",
|
||||
)
|
||||
|
||||
# Emit compaction end if SDK finished compacting.
|
||||
# When compaction ends, sync TranscriptBuilder with the
|
||||
# CLI's active context so they stay identical.
|
||||
compact_result = await compaction.emit_end_if_ready(session)
|
||||
for ev in compact_result.events:
|
||||
# Emit compaction end if SDK finished compacting
|
||||
for ev in await compaction.emit_end_if_ready(session):
|
||||
yield ev
|
||||
# After replace_entries, skip append_assistant for this
|
||||
# sdk_msg — the CLI session file already contains it,
|
||||
# so appending again would create a duplicate.
|
||||
entries_replaced = False
|
||||
if compact_result.just_ended:
|
||||
compacted = await asyncio.to_thread(
|
||||
read_compacted_entries,
|
||||
compact_result.transcript_path,
|
||||
)
|
||||
if compacted is not None:
|
||||
transcript_builder.replace_entries(
|
||||
compacted, log_prefix=log_prefix
|
||||
)
|
||||
entries_replaced = True
|
||||
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
@@ -1251,11 +1218,10 @@ async def stream_chat_completion_sdk(
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
if not entries_replaced:
|
||||
transcript_builder.append_tool_result(
|
||||
tool_use_id=response.toolCallId,
|
||||
content=content,
|
||||
)
|
||||
transcript_builder.append_tool_result(
|
||||
tool_use_id=response.toolCallId,
|
||||
content=content,
|
||||
)
|
||||
has_tool_results = True
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
@@ -1265,9 +1231,7 @@ async def stream_chat_completion_sdk(
|
||||
# any stashed tool results from the previous turn are
|
||||
# recorded first, preserving the required API order:
|
||||
# assistant(tool_use) → tool_result → assistant(text).
|
||||
# Skip if replace_entries just ran — the CLI session
|
||||
# file already contains this message.
|
||||
if isinstance(sdk_msg, AssistantMessage) and not entries_replaced:
|
||||
if isinstance(sdk_msg, AssistantMessage):
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
|
||||
model=sdk_msg.model,
|
||||
@@ -1458,13 +1422,13 @@ async def stream_chat_completion_sdk(
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# TranscriptBuilder is the single source of truth. It mirrors the
|
||||
# CLI's active context: on compaction, replace_entries() syncs it
|
||||
# with the compacted session file. No CLI file read needed here.
|
||||
# This MUST run in finally so the transcript is uploaded even when
|
||||
# the streaming loop raises an exception.
|
||||
# The transcript represents the COMPLETE active context (atomic).
|
||||
if config.claude_agent_use_resume and user_id and session is not None:
|
||||
try:
|
||||
# Build complete transcript from captured SDK messages
|
||||
transcript_content = transcript_builder.to_jsonl()
|
||||
entry_count = transcript_builder.entry_count
|
||||
|
||||
if not transcript_content:
|
||||
logger.warning(
|
||||
@@ -1474,15 +1438,18 @@ async def stream_chat_completion_sdk(
|
||||
logger.warning(
|
||||
"%s Transcript invalid, skipping upload (entries=%d)",
|
||||
log_prefix,
|
||||
entry_count,
|
||||
transcript_builder.entry_count,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s Uploading transcript (entries=%d, bytes=%d)",
|
||||
"%s Uploading complete transcript (entries=%d, bytes=%d)",
|
||||
log_prefix,
|
||||
entry_count,
|
||||
transcript_builder.entry_count,
|
||||
len(transcript_content),
|
||||
)
|
||||
# Shield upload from cancellation - let it complete even if
|
||||
# the finally block is interrupted. No timeout to avoid race
|
||||
# conditions where backgrounded uploads overwrite newer transcripts.
|
||||
await asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
@@ -1500,14 +1467,11 @@ async def stream_chat_completion_sdk(
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
if sdk_cwd:
|
||||
await _cleanup_sdk_tool_results(sdk_cwd)
|
||||
except Exception:
|
||||
logger.warning("%s SDK cleanup failed", log_prefix, exc_info=True)
|
||||
finally:
|
||||
# Release stream lock to allow new streams for this session
|
||||
await lock.release()
|
||||
if sdk_cwd:
|
||||
_cleanup_sdk_tool_results(sdk_cwd)
|
||||
|
||||
# Release stream lock to allow new streams for this session
|
||||
await lock.release()
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
|
||||
@@ -20,7 +20,7 @@ class _FakeFileInfo:
|
||||
size_bytes: int
|
||||
|
||||
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_workspace_manager"
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
|
||||
|
||||
|
||||
class TestPrepareFileAttachments:
|
||||
@@ -288,90 +288,3 @@ class TestPromptSupplement:
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cleanup_sdk_tool_results — orchestration + rate-limiting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupSdkToolResults:
|
||||
"""Tests for _cleanup_sdk_tool_results orchestration and sweep rate-limiting."""
|
||||
|
||||
# All valid cwds must start with /tmp/copilot- (the _SDK_CWD_PREFIX).
|
||||
_CWD_PREFIX = "/tmp/copilot-"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removes_cwd_directory(self):
|
||||
"""Cleanup removes the session working directory."""
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-cleanup-remove"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch("backend.copilot.sdk.service.cleanup_stale_project_dirs"):
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
svc_mod._last_sweep_time = 0.0
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
assert not os.path.exists(cwd)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_runs_when_interval_elapsed(self):
|
||||
"""cleanup_stale_project_dirs is called when 5-minute interval has elapsed."""
|
||||
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-sweep-elapsed"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
# Set last sweep to a time far in the past
|
||||
svc_mod._last_sweep_time = 0.0
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
mock_sweep.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_skipped_within_interval(self):
|
||||
"""cleanup_stale_project_dirs is NOT called when within 5-minute interval."""
|
||||
import time
|
||||
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-sweep-ratelimit"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
# Set last sweep to now — interval not elapsed
|
||||
svc_mod._last_sweep_time = time.time()
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
mock_sweep.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_path_outside_prefix(self, tmp_path):
|
||||
"""Cleanup rejects a cwd that does not start with the expected prefix."""
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
evil_cwd = str(tmp_path / "evil-path")
|
||||
os.makedirs(evil_cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
await _cleanup_sdk_tool_results(evil_cwd)
|
||||
|
||||
# Directory should NOT have been removed (rejected early)
|
||||
assert os.path.exists(evil_cwd)
|
||||
mock_sweep.assert_not_called()
|
||||
|
||||
@@ -146,7 +146,7 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
|
||||
event.set()
|
||||
|
||||
|
||||
async def wait_for_stash(timeout: float = 2.0) -> bool:
|
||||
async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
"""Wait for a PostToolUse hook to stash tool output.
|
||||
|
||||
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
|
||||
@@ -155,12 +155,12 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
|
||||
by waiting on the ``_stash_event``, which is signaled by
|
||||
:func:`stash_pending_tool_output`.
|
||||
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
After the event fires, callers should ``await asyncio.sleep(0)`` to
|
||||
give any remaining concurrent hooks a chance to complete.
|
||||
|
||||
The 2.0 s default was chosen based on production metrics: the original
|
||||
0.5 s caused frequent timeouts under load (parallel tool calls, large
|
||||
outputs). 2.0 s gives a comfortable margin while still failing fast
|
||||
when the hook genuinely will not fire.
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
The timeout is a safety net — normally the stash happens within
|
||||
microseconds of yielding to the event loop.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is None:
|
||||
@@ -285,7 +285,7 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(resolved, encoding="utf-8", errors="replace") as f:
|
||||
with open(resolved) as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
@@ -347,7 +347,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
:func:`get_sdk_disallowed_tools`.
|
||||
"""
|
||||
|
||||
def _truncating(fn, tool_name: str, input_schema: dict[str, Any] | None = None):
|
||||
def _truncating(fn, tool_name: str):
|
||||
"""Wrap a tool handler so its response is truncated to stay under the
|
||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||
response adapter before the SDK can apply its own head-truncation.
|
||||
@@ -361,9 +361,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
user_id, session = get_execution_context()
|
||||
if session is not None:
|
||||
try:
|
||||
args = await expand_file_refs_in_args(
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
args = await expand_file_refs_in_args(args, user_id, session)
|
||||
except FileRefExpansionError as exc:
|
||||
return _mcp_error(
|
||||
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||
@@ -391,12 +389,11 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
schema = _build_input_schema(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
schema,
|
||||
)(_truncating(handler, tool_name, input_schema=schema))
|
||||
_build_input_schema(base_tool),
|
||||
)(_truncating(handler, tool_name))
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
|
||||
|
||||
@@ -13,10 +13,8 @@ filesystem for self-hosted) — no DB column needed.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from backend.util import json
|
||||
|
||||
@@ -84,11 +82,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
parent = entry.get("parentUuid", "")
|
||||
if uid:
|
||||
uuid_to_parent[uid] = parent
|
||||
if (
|
||||
entry.get("type", "") in STRIPPABLE_TYPES
|
||||
and uid
|
||||
and not entry.get("isCompactSummary")
|
||||
):
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
|
||||
stripped_uuids.add(uid)
|
||||
|
||||
# Second pass: keep non-stripped entries, reparenting where needed.
|
||||
@@ -112,9 +106,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
if not isinstance(entry, dict):
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
if uid in reparented:
|
||||
@@ -145,180 +137,32 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _projects_base() -> str:
|
||||
"""Return the resolved path to the CLI's projects directory."""
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||
safe to remove entirely after the transcript has been uploaded.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
return os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
|
||||
|
||||
_STALE_PROJECT_DIR_SECONDS = 12 * 3600 # 12 hours — matches max session lifetime
|
||||
_MAX_PROJECT_DIRS_TO_SWEEP = 50 # limit per sweep to avoid long pauses
|
||||
|
||||
|
||||
def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
"""Remove CLI project directories older than ``_STALE_PROJECT_DIR_SECONDS``.
|
||||
|
||||
Each CoPilot SDK turn creates a unique ``~/.claude/projects/<encoded-cwd>/``
|
||||
directory. These are intentionally kept across turns so the model can read
|
||||
tool-result files via ``--resume``. However, after a session ends they
|
||||
become stale. This function sweeps old ones to prevent unbounded disk
|
||||
growth.
|
||||
|
||||
When *encoded_cwd* is provided the sweep is scoped to that single
|
||||
directory, making the operation safe in multi-tenant environments where
|
||||
multiple copilot sessions share the same host. Without it the function
|
||||
falls back to sweeping all directories matching the copilot naming pattern
|
||||
(``-tmp-copilot-``), which is only safe for single-tenant deployments.
|
||||
|
||||
Returns the number of directories removed.
|
||||
"""
|
||||
projects_base = _projects_base()
|
||||
if not os.path.isdir(projects_base):
|
||||
return 0
|
||||
|
||||
now = time.time()
|
||||
removed = 0
|
||||
|
||||
# Scoped mode: only clean up the one directory for the current session.
|
||||
if encoded_cwd:
|
||||
target = Path(projects_base) / encoded_cwd
|
||||
if not target.is_dir():
|
||||
return 0
|
||||
# Guard: only sweep copilot-generated dirs.
|
||||
if "-tmp-copilot-" not in target.name:
|
||||
logger.warning(
|
||||
"[Transcript] Refusing to sweep non-copilot dir: %s", target.name
|
||||
)
|
||||
return 0
|
||||
try:
|
||||
# st_mtime is used as a proxy for session activity. Claude CLI writes
|
||||
# its JSONL transcript into this directory during each turn, so mtime
|
||||
# advances on every turn. A directory whose mtime is older than
|
||||
# _STALE_PROJECT_DIR_SECONDS has not had an active turn in that window
|
||||
# and is safe to remove (the session cannot --resume after cleanup).
|
||||
age = now - target.stat().st_mtime
|
||||
except OSError:
|
||||
return 0
|
||||
if age < _STALE_PROJECT_DIR_SECONDS:
|
||||
return 0
|
||||
try:
|
||||
shutil.rmtree(target, ignore_errors=True)
|
||||
removed = 1
|
||||
except OSError:
|
||||
pass
|
||||
if removed:
|
||||
logger.info(
|
||||
"[Transcript] Swept stale CLI project dir %s (age %ds > %ds)",
|
||||
target.name,
|
||||
int(age),
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
)
|
||||
return removed
|
||||
|
||||
# Unscoped fallback: sweep all copilot dirs across the projects base.
|
||||
# Only safe for single-tenant deployments; callers should prefer the
|
||||
# scoped variant by passing encoded_cwd.
|
||||
try:
|
||||
entries = Path(projects_base).iterdir()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to list projects dir: %s", e)
|
||||
return 0
|
||||
|
||||
for entry in entries:
|
||||
if removed >= _MAX_PROJECT_DIRS_TO_SWEEP:
|
||||
break
|
||||
# Only sweep copilot-generated dirs (pattern: -tmp-copilot- or
|
||||
# -private-tmp-copilot-).
|
||||
if "-tmp-copilot-" not in entry.name:
|
||||
continue
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
try:
|
||||
# See the scoped-mode comment above: st_mtime advances on every turn,
|
||||
# so a stale mtime reliably indicates an inactive session.
|
||||
age = now - entry.stat().st_mtime
|
||||
except OSError:
|
||||
continue
|
||||
if age < _STALE_PROJECT_DIR_SECONDS:
|
||||
continue
|
||||
|
||||
try:
|
||||
shutil.rmtree(entry, ignore_errors=True)
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if removed:
|
||||
logger.info(
|
||||
"[Transcript] Swept %d stale CLI project dirs (older than %ds)",
|
||||
removed,
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
)
|
||||
return removed
|
||||
|
||||
|
||||
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
"""Read compacted entries from the CLI session file after compaction.
|
||||
|
||||
Parses the JSONL file line-by-line, finds the ``isCompactSummary: true``
|
||||
entry, and returns it plus all entries after it.
|
||||
|
||||
The CLI writes the compaction summary BEFORE sending the next message,
|
||||
so the file is guaranteed to be flushed by the time we read it.
|
||||
|
||||
Returns a list of parsed dicts, or ``None`` if the file cannot be read
|
||||
or no compaction summary is found.
|
||||
"""
|
||||
if not transcript_path:
|
||||
return None
|
||||
|
||||
projects_base = _projects_base()
|
||||
real_path = os.path.realpath(transcript_path)
|
||||
if not real_path.startswith(projects_base + os.sep):
|
||||
if not project_dir.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"[Transcript] transcript_path outside projects base: %s", transcript_path
|
||||
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
|
||||
)
|
||||
return None
|
||||
return
|
||||
|
||||
try:
|
||||
content = Path(real_path).read_text()
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
"[Transcript] Failed to read session file %s: %s", transcript_path, e
|
||||
)
|
||||
return None
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
compact_idx: int | None = None
|
||||
|
||||
for idx, line in enumerate(lines):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("isCompactSummary"):
|
||||
compact_idx = idx # don't break — find the LAST summary
|
||||
|
||||
if compact_idx is None:
|
||||
logger.debug("[Transcript] No compaction summary found in %s", transcript_path)
|
||||
return None
|
||||
|
||||
entries: list[dict] = []
|
||||
for line in lines[compact_idx:]:
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if isinstance(entry, dict):
|
||||
entries.append(entry)
|
||||
|
||||
logger.info(
|
||||
"[Transcript] Read %d compacted entries from %s (summary at line %d)",
|
||||
len(entries),
|
||||
transcript_path,
|
||||
compact_idx + 1,
|
||||
)
|
||||
return entries
|
||||
if os.path.isdir(project_dir):
|
||||
shutil.rmtree(project_dir, ignore_errors=True)
|
||||
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
|
||||
else:
|
||||
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
@@ -415,27 +259,24 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
|
||||
)
|
||||
|
||||
|
||||
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
"""Build the full storage path string that ``retrieve()`` expects.
|
||||
|
||||
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
|
||||
``local://workspace_id/file_id/filename``. Since we use deterministic
|
||||
arguments we can reconstruct the same path for download/delete without
|
||||
having stored the return value.
|
||||
"""
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
wid, fid, fname = parts
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
|
||||
if isinstance(backend, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||
return f"gcs://{backend.bucket_name}/{blob}"
|
||||
return f"local://{wid}/{fid}/{fname}"
|
||||
|
||||
|
||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
"""Build the full storage path string that ``retrieve()`` expects."""
|
||||
return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend)
|
||||
|
||||
|
||||
def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
"""Build the full storage path for the companion .meta.json file."""
|
||||
return _build_path_from_parts(
|
||||
_meta_storage_path_parts(user_id, session_id), backend
|
||||
)
|
||||
else:
|
||||
# LocalWorkspaceStorage returns local://{relative_path}
|
||||
return f"local://{wid}/{fid}/{fname}"
|
||||
|
||||
|
||||
async def upload_transcript(
|
||||
@@ -540,7 +381,15 @@ async def download_transcript(
|
||||
message_count = 0
|
||||
uploaded_at = 0.0
|
||||
try:
|
||||
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
if isinstance(storage, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{mwid}/{mfid}/{mfname}"
|
||||
meta_path = f"gcs://{storage.bucket_name}/{blob}"
|
||||
else:
|
||||
meta_path = f"local://{mwid}/{mfid}/{mfname}"
|
||||
|
||||
meta_data = await storage.retrieve(meta_path)
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
@@ -557,11 +406,7 @@ async def download_transcript(
|
||||
|
||||
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
"""Delete transcript and its metadata from bucket storage.
|
||||
|
||||
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||
"""
|
||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
@@ -569,14 +414,6 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
|
||||
try:
|
||||
await storage.delete(path)
|
||||
logger.info("[Transcript] Deleted transcript for session %s", session_id)
|
||||
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete transcript: %s", e)
|
||||
|
||||
# Also delete the companion .meta.json to avoid orphaned metadata.
|
||||
try:
|
||||
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||
await storage.delete(meta_path)
|
||||
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
|
||||
|
||||
@@ -30,7 +30,6 @@ class TranscriptEntry(BaseModel):
|
||||
type: str
|
||||
uuid: str
|
||||
parentUuid: str | None
|
||||
isCompactSummary: bool | None = None
|
||||
message: dict[str, Any]
|
||||
|
||||
|
||||
@@ -54,24 +53,6 @@ class TranscriptBuilder:
|
||||
return self._entries[-1].message.get("id", "")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _parse_entry(data: dict) -> TranscriptEntry | None:
|
||||
"""Parse a single transcript entry, filtering strippable types.
|
||||
|
||||
Returns ``None`` for entries that should be skipped (strippable types
|
||||
that are not compaction summaries).
|
||||
"""
|
||||
entry_type = data.get("type", "")
|
||||
if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
|
||||
return None
|
||||
return TranscriptEntry(
|
||||
type=entry_type,
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
isCompactSummary=data.get("isCompactSummary") or None,
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
|
||||
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
||||
"""Load complete previous transcript.
|
||||
|
||||
@@ -97,9 +78,18 @@ class TranscriptBuilder:
|
||||
)
|
||||
continue
|
||||
|
||||
entry = self._parse_entry(data)
|
||||
if entry is None:
|
||||
# Load all non-strippable entries (user/assistant/system/etc.)
|
||||
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
|
||||
entry_type = data.get("type", "")
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
continue
|
||||
|
||||
entry = TranscriptEntry(
|
||||
type=data["type"],
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._last_uuid = entry.uuid
|
||||
|
||||
@@ -172,43 +162,6 @@ class TranscriptBuilder:
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def replace_entries(
|
||||
self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
|
||||
) -> None:
|
||||
"""Replace all entries with compacted entries from the CLI session file.
|
||||
|
||||
Called after mid-stream compaction so TranscriptBuilder mirrors the
|
||||
CLI's active context (compaction summary + post-compaction entries).
|
||||
|
||||
Builds the new list first and validates it's non-empty before swapping,
|
||||
so corrupt input cannot wipe the conversation history.
|
||||
"""
|
||||
new_entries: list[TranscriptEntry] = []
|
||||
for data in compacted_entries:
|
||||
entry = self._parse_entry(data)
|
||||
if entry is not None:
|
||||
new_entries.append(entry)
|
||||
|
||||
if not new_entries:
|
||||
logger.warning(
|
||||
"%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
|
||||
log_prefix,
|
||||
len(compacted_entries),
|
||||
len(self._entries),
|
||||
)
|
||||
return
|
||||
|
||||
old_count = len(self._entries)
|
||||
self._entries = new_entries
|
||||
self._last_uuid = new_entries[-1].uuid
|
||||
|
||||
logger.info(
|
||||
"%s TranscriptBuilder compacted: %d entries -> %d entries",
|
||||
log_prefix,
|
||||
old_count,
|
||||
len(self._entries),
|
||||
)
|
||||
|
||||
def to_jsonl(self) -> str:
|
||||
"""Export complete context as JSONL.
|
||||
|
||||
|
||||
@@ -1,21 +1,15 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
delete_transcript,
|
||||
read_compacted_entries,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
@@ -288,737 +282,3 @@ class TestStripProgressEntries:
|
||||
lines = result.strip().split("\n")
|
||||
asst_entry = json.loads(lines[-1])
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
|
||||
# --- delete_transcript ---
|
||||
|
||||
|
||||
class TestDeleteTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_deletes_both_jsonl_and_meta(self):
|
||||
"""delete_transcript removes both the .jsonl and .meta.json files."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
await delete_transcript("user-123", "session-456")
|
||||
|
||||
assert mock_storage.delete.call_count == 2
|
||||
paths = [call.args[0] for call in mock_storage.delete.call_args_list]
|
||||
assert any(p.endswith(".jsonl") for p in paths)
|
||||
assert any(p.endswith(".meta.json") for p in paths)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_continues_on_jsonl_delete_failure(self):
|
||||
"""If .jsonl delete fails, .meta.json delete is still attempted."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock(
|
||||
side_effect=[Exception("jsonl delete failed"), None]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
# Should not raise
|
||||
await delete_transcript("user-123", "session-456")
|
||||
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_meta_delete_failure(self):
|
||||
"""If .meta.json delete fails, no exception propagates."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock(
|
||||
side_effect=[None, Exception("meta delete failed")]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
# Should not raise
|
||||
await delete_transcript("user-123", "session-456")
|
||||
|
||||
|
||||
# --- read_compacted_entries ---
|
||||
|
||||
|
||||
COMPACT_SUMMARY = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "assistant", "content": "compacted context"},
|
||||
}
|
||||
POST_COMPACT_ASST = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "cs1",
|
||||
"message": {"role": "assistant", "content": "response after compaction"},
|
||||
}
|
||||
|
||||
|
||||
class TestReadCompactedEntries:
|
||||
def test_returns_summary_and_entries_after(self, tmp_path, monkeypatch):
|
||||
"""File with isCompactSummary entry returns summary + entries after."""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
pre_compact = {"type": "user", "uuid": "u1", "message": {"role": "user"}}
|
||||
path = session_dir / "session.jsonl"
|
||||
path.write_text(_make_jsonl(pre_compact, COMPACT_SUMMARY, POST_COMPACT_ASST))
|
||||
|
||||
result = read_compacted_entries(str(path))
|
||||
assert result is not None
|
||||
assert len(result) == 2
|
||||
assert result[0]["isCompactSummary"] is True
|
||||
assert result[1]["uuid"] == "a2"
|
||||
|
||||
def test_no_compact_summary_returns_none(self, tmp_path, monkeypatch):
|
||||
"""File without isCompactSummary returns None."""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
path = session_dir / "session.jsonl"
|
||||
path.write_text(_make_jsonl(USER_MSG, ASST_MSG))
|
||||
|
||||
result = read_compacted_entries(str(path))
|
||||
assert result is None
|
||||
|
||||
def test_file_not_found_returns_none(self, tmp_path, monkeypatch):
|
||||
"""Non-existent file returns None."""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
projects_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
result = read_compacted_entries(str(projects_dir / "missing.jsonl"))
|
||||
assert result is None
|
||||
|
||||
def test_empty_path_returns_none(self):
|
||||
"""Empty string path returns None."""
|
||||
result = read_compacted_entries("")
|
||||
assert result is None
|
||||
|
||||
def test_malformed_json_lines_skipped(self, tmp_path, monkeypatch):
|
||||
"""Malformed JSON lines are skipped gracefully."""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
path = session_dir / "session.jsonl"
|
||||
content = "not valid json\n" + json.dumps(COMPACT_SUMMARY) + "\n"
|
||||
content += "also bad\n" + json.dumps(POST_COMPACT_ASST) + "\n"
|
||||
path.write_text(content)
|
||||
|
||||
result = read_compacted_entries(str(path))
|
||||
assert result is not None
|
||||
assert len(result) == 2 # summary + post-compact assistant
|
||||
|
||||
def test_multiple_compact_summaries_uses_last(self, tmp_path, monkeypatch):
|
||||
"""When multiple isCompactSummary entries exist, uses the last one
|
||||
(most recent compaction)."""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
second_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs2",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "assistant", "content": "second summary"},
|
||||
}
|
||||
path = session_dir / "session.jsonl"
|
||||
path.write_text(_make_jsonl(COMPACT_SUMMARY, POST_COMPACT_ASST, second_summary))
|
||||
|
||||
result = read_compacted_entries(str(path))
|
||||
assert result is not None
|
||||
# Last summary found, so only cs2 returned
|
||||
assert len(result) == 1
|
||||
assert result[0]["uuid"] == "cs2"
|
||||
|
||||
def test_path_outside_projects_base_returns_none(self, tmp_path, monkeypatch):
|
||||
"""Transcript path outside the projects directory is rejected."""
|
||||
config_dir = tmp_path / "config"
|
||||
(config_dir / "projects").mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
evil_file = tmp_path / "evil.jsonl"
|
||||
evil_file.write_text(_make_jsonl(COMPACT_SUMMARY))
|
||||
|
||||
result = read_compacted_entries(str(evil_file))
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- TranscriptBuilder.replace_entries ---
|
||||
|
||||
|
||||
class TestTranscriptBuilderReplaceEntries:
|
||||
def test_replaces_existing_entries(self):
|
||||
"""replace_entries replaces all entries with compacted ones."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
builder.append_assistant([{"type": "text", "text": "world"}])
|
||||
assert builder.entry_count == 2
|
||||
|
||||
compacted = [
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "compacted summary"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "cs1",
|
||||
"message": {"role": "assistant", "content": "response"},
|
||||
},
|
||||
]
|
||||
builder.replace_entries(compacted)
|
||||
assert builder.entry_count == 2
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert entries[0]["uuid"] == "cs1"
|
||||
assert entries[1]["uuid"] == "a1"
|
||||
|
||||
def test_filters_strippable_types(self):
|
||||
"""Strippable types are filtered out during replace."""
|
||||
builder = TranscriptBuilder()
|
||||
compacted = [
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "cs1",
|
||||
"message": {"role": "user", "content": "compacted summary"},
|
||||
},
|
||||
{"type": "progress", "uuid": "p1", "message": {}},
|
||||
{"type": "summary", "uuid": "s1", "message": {}},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "cs1",
|
||||
"message": {"role": "assistant", "content": "hi"},
|
||||
},
|
||||
]
|
||||
builder.replace_entries(compacted)
|
||||
assert builder.entry_count == 2 # progress and summary were filtered
|
||||
|
||||
def test_maintains_last_uuid_chain(self):
|
||||
"""After replace, _last_uuid is the last entry's uuid."""
|
||||
builder = TranscriptBuilder()
|
||||
compacted = [
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "cs1",
|
||||
"message": {"role": "user", "content": "compacted summary"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "cs1",
|
||||
"message": {"role": "assistant", "content": "hi"},
|
||||
},
|
||||
]
|
||||
builder.replace_entries(compacted)
|
||||
# Appending a new user message should chain to a1
|
||||
builder.append_user("next question")
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert entries[-1]["parentUuid"] == "a1"
|
||||
|
||||
def test_empty_entries_list_keeps_existing(self):
|
||||
"""Replacing with empty list keeps existing entries (safety check)."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
builder.replace_entries([])
|
||||
# Empty input is treated as corrupt — existing entries preserved
|
||||
assert builder.entry_count == 1
|
||||
assert not builder.is_empty
|
||||
|
||||
|
||||
# --- TranscriptBuilder.load_previous with compacted content ---
|
||||
|
||||
|
||||
class TestTranscriptBuilderLoadPreviousCompacted:
|
||||
def test_preserves_compact_summary_entry(self):
|
||||
"""load_previous preserves isCompactSummary entries even though
|
||||
their type is 'summary' (which is in STRIPPABLE_TYPES)."""
|
||||
compacted_content = _make_jsonl(COMPACT_SUMMARY, POST_COMPACT_ASST)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(compacted_content)
|
||||
assert builder.entry_count == 2
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert entries[0]["type"] == "summary"
|
||||
assert entries[0]["uuid"] == "cs1"
|
||||
assert entries[1]["uuid"] == "a2"
|
||||
|
||||
def test_strips_regular_summary_entries(self):
|
||||
"""Regular summary entries (without isCompactSummary) are still stripped."""
|
||||
regular_summary = {"type": "summary", "uuid": "s1", "message": {"content": "x"}}
|
||||
content = _make_jsonl(regular_summary, POST_COMPACT_ASST)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(content)
|
||||
assert builder.entry_count == 1 # Only the assistant entry
|
||||
|
||||
|
||||
# --- End-to-end compaction flow (simulates service.py) ---
|
||||
|
||||
|
||||
class TestCompactionFlowIntegration:
|
||||
"""Simulate the full compaction flow as it happens in service.py:
|
||||
|
||||
1. TranscriptBuilder loads a previous transcript (download)
|
||||
2. New messages are appended (user query + assistant response)
|
||||
3. CompactionTracker fires (PreCompact hook → emit_start → emit_end)
|
||||
4. read_compacted_entries reads the CLI session file
|
||||
5. TranscriptBuilder.replace_entries syncs with CLI state
|
||||
6. Final to_jsonl() produces the correct output (upload)
|
||||
"""
|
||||
|
||||
def test_full_compaction_roundtrip(self, tmp_path, monkeypatch):
|
||||
"""Full roundtrip: load → append → compact → replace → export."""
|
||||
# Setup: create a CLI session file with pre-compact + compaction entries
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# Simulate a transcript with old messages, then a compaction summary
|
||||
old_user = {
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "old question"},
|
||||
}
|
||||
old_asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {"role": "assistant", "content": "old answer"},
|
||||
}
|
||||
compact_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "compacted summary of conversation"},
|
||||
}
|
||||
post_compact_asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "cs1",
|
||||
"message": {"role": "assistant", "content": "response after compaction"},
|
||||
}
|
||||
session_file = session_dir / "session.jsonl"
|
||||
session_file.write_text(
|
||||
_make_jsonl(old_user, old_asst, compact_summary, post_compact_asst)
|
||||
)
|
||||
|
||||
# Step 1: TranscriptBuilder loads previous transcript (simulates download)
|
||||
# The previous transcript would have the OLD entries (pre-compaction)
|
||||
previous_transcript = _make_jsonl(old_user, old_asst)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(previous_transcript)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# Step 2: New messages appended during the current query
|
||||
builder.append_user("new question")
|
||||
builder.append_assistant([{"type": "text", "text": "new answer"}])
|
||||
assert builder.entry_count == 4
|
||||
|
||||
# Step 3: read_compacted_entries reads the CLI session file
|
||||
compacted = read_compacted_entries(str(session_file))
|
||||
assert compacted is not None
|
||||
assert len(compacted) == 2 # compact_summary + post_compact_asst
|
||||
assert compacted[0]["isCompactSummary"] is True
|
||||
|
||||
# Step 4: replace_entries syncs builder with CLI state
|
||||
builder.replace_entries(compacted)
|
||||
assert builder.entry_count == 2 # Only compacted entries now
|
||||
|
||||
# Step 5: Append post-compaction messages (continuing the stream)
|
||||
builder.append_user("follow-up question")
|
||||
assert builder.entry_count == 3
|
||||
|
||||
# Step 6: Export and verify
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert len(entries) == 3
|
||||
# First entry is the compaction summary
|
||||
assert entries[0]["type"] == "summary"
|
||||
assert entries[0]["uuid"] == "cs1"
|
||||
# Second is the post-compact assistant
|
||||
assert entries[1]["uuid"] == "a2"
|
||||
# Third is our follow-up, parented to the last compacted entry
|
||||
assert entries[2]["type"] == "user"
|
||||
assert entries[2]["parentUuid"] == "a2"
|
||||
|
||||
def test_compaction_preserves_chain_across_multiple_compactions(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""Two compactions: first compacts old history, second compacts the first."""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# First compaction
|
||||
first_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "first summary"},
|
||||
}
|
||||
mid_asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "cs1",
|
||||
"message": {"role": "assistant", "content": "mid response"},
|
||||
}
|
||||
# Second compaction (compacts the first summary + mid_asst)
|
||||
second_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs2",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "second summary"},
|
||||
}
|
||||
final_asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "cs2",
|
||||
"message": {"role": "assistant", "content": "final response"},
|
||||
}
|
||||
|
||||
session_file = session_dir / "session.jsonl"
|
||||
session_file.write_text(
|
||||
_make_jsonl(first_summary, mid_asst, second_summary, final_asst)
|
||||
)
|
||||
|
||||
# read_compacted_entries should find the LAST summary
|
||||
compacted = read_compacted_entries(str(session_file))
|
||||
assert compacted is not None
|
||||
assert len(compacted) == 2 # second_summary + final_asst
|
||||
assert compacted[0]["uuid"] == "cs2"
|
||||
|
||||
# Apply to builder
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("old stuff")
|
||||
builder.append_assistant([{"type": "text", "text": "old response"}])
|
||||
builder.replace_entries(compacted)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# New message chains correctly
|
||||
builder.append_user("after second compaction")
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert entries[-1]["parentUuid"] == "a2"
|
||||
|
||||
def test_strip_progress_preserves_compact_summaries(self):
|
||||
"""strip_progress_entries doesn't strip isCompactSummary entries
|
||||
even though their type is 'summary' (in STRIPPABLE_TYPES)."""
|
||||
compact_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "compacted"},
|
||||
}
|
||||
regular_summary = {"type": "summary", "uuid": "s1", "message": {"content": "x"}}
|
||||
progress = {"type": "progress", "uuid": "p1", "data": {"stdout": "..."}}
|
||||
user = {
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
}
|
||||
|
||||
content = _make_jsonl(compact_summary, regular_summary, progress, user)
|
||||
stripped = strip_progress_entries(content)
|
||||
stripped_entries = [
|
||||
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
|
||||
]
|
||||
|
||||
uuids = [e.get("uuid") for e in stripped_entries]
|
||||
# compact_summary kept, regular_summary stripped, progress stripped, user kept
|
||||
assert "cs1" in uuids # compact summary preserved
|
||||
assert "s1" not in uuids # regular summary stripped
|
||||
assert "p1" not in uuids # progress stripped
|
||||
assert "u1" in uuids # user kept
|
||||
|
||||
def test_builder_load_then_replace_then_export_roundtrip(self):
|
||||
"""Load a compacted transcript, replace with new compaction, export.
|
||||
Simulates two consecutive turns with compaction each time."""
|
||||
# Turn 1: load compacted transcript
|
||||
compact1 = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "summary v1"},
|
||||
}
|
||||
asst1 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "cs1",
|
||||
"message": {"role": "assistant", "content": "response 1"},
|
||||
}
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(_make_jsonl(compact1, asst1))
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# Turn 1: append new messages
|
||||
builder.append_user("question")
|
||||
builder.append_assistant([{"type": "text", "text": "answer"}])
|
||||
assert builder.entry_count == 4
|
||||
|
||||
# Turn 1: compaction fires — replace with new compacted state
|
||||
compact2 = {
|
||||
"type": "summary",
|
||||
"uuid": "cs2",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "summary v2"},
|
||||
}
|
||||
asst2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "cs2",
|
||||
"message": {"role": "assistant", "content": "continuing"},
|
||||
}
|
||||
builder.replace_entries([compact2, asst2])
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# Export (this goes to cloud storage for next turn's download)
|
||||
output = builder.to_jsonl()
|
||||
lines = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert lines[0]["uuid"] == "cs2"
|
||||
assert lines[0]["type"] == "summary"
|
||||
assert lines[1]["uuid"] == "a2"
|
||||
|
||||
# Turn 2: fresh builder loads the exported transcript
|
||||
builder2 = TranscriptBuilder()
|
||||
builder2.load_previous(output)
|
||||
assert builder2.entry_count == 2
|
||||
builder2.append_user("turn 2 question")
|
||||
output2 = builder2.to_jsonl()
|
||||
lines2 = [json.loads(line) for line in output2.strip().split("\n")]
|
||||
assert lines2[-1]["parentUuid"] == "a2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cleanup_stale_project_dirs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupStaleProjectDirs:
|
||||
"""Tests for cleanup_stale_project_dirs (disk leak prevention)."""
|
||||
|
||||
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories matching copilot pattern older than threshold are removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
# Create a stale dir
|
||||
stale = projects_dir / "-tmp-copilot-old-session"
|
||||
stale.mkdir()
|
||||
# Set mtime to past the threshold
|
||||
import time
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
os.utime(stale, (old_time, old_time))
|
||||
|
||||
# Create a fresh dir
|
||||
fresh = projects_dir / "-tmp-copilot-new-session"
|
||||
fresh.mkdir()
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 1
|
||||
assert not stale.exists()
|
||||
assert fresh.exists()
|
||||
|
||||
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories not matching copilot pattern are left alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
# Non-copilot dir that's old
|
||||
import time
|
||||
|
||||
other = projects_dir / "some-other-project"
|
||||
other.mkdir()
|
||||
old_time = time.time() - 999999
|
||||
os.utime(other, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert other.exists()
|
||||
|
||||
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
|
||||
"""A directory exactly at the TTL boundary should NOT be removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
# Dir that's exactly at the TTL (age == threshold, not >) — should survive
|
||||
boundary = projects_dir / "-tmp-copilot-boundary"
|
||||
boundary.mkdir()
|
||||
boundary_time = time.time() - _STALE_PROJECT_DIR_SECONDS + 1
|
||||
os.utime(boundary, (boundary_time, boundary_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert boundary.exists()
|
||||
|
||||
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
|
||||
"""Regular files matching the copilot pattern are not removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
# Create a regular FILE (not a dir) with the copilot pattern name
|
||||
stale_file = projects_dir / "-tmp-copilot-stale-file"
|
||||
stale_file.write_text("not a dir")
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
os.utime(stale_file, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert stale_file.exists()
|
||||
|
||||
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
|
||||
"""If the projects base directory doesn't exist, return 0 gracefully."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: nonexistent,
|
||||
)
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
|
||||
def test_scoped_removes_only_target_dir(self, tmp_path, monkeypatch):
|
||||
"""When encoded_cwd is supplied only that directory is swept."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
|
||||
# Two stale copilot dirs
|
||||
target = projects_dir / "-tmp-copilot-session-abc"
|
||||
target.mkdir()
|
||||
os.utime(target, (old_time, old_time))
|
||||
|
||||
other = projects_dir / "-tmp-copilot-session-xyz"
|
||||
other.mkdir()
|
||||
os.utime(other, (old_time, old_time))
|
||||
|
||||
# Only the target dir should be removed
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-abc")
|
||||
assert removed == 1
|
||||
assert not target.exists()
|
||||
assert other.exists() # untouched — not the current session
|
||||
|
||||
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep leaves a fresh directory alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
fresh = projects_dir / "-tmp-copilot-session-new"
|
||||
fresh.mkdir()
|
||||
# mtime is now — well within TTL
|
||||
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-new")
|
||||
assert removed == 0
|
||||
assert fresh.exists()
|
||||
|
||||
def test_scoped_non_copilot_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep refuses to remove a non-copilot directory."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
non_copilot = projects_dir / "some-other-project"
|
||||
non_copilot.mkdir()
|
||||
os.utime(non_copilot, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
|
||||
assert removed == 0
|
||||
assert non_copilot.exists()
|
||||
|
||||
@@ -23,11 +23,6 @@ from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.notification_bus import (
|
||||
AsyncRedisNotificationEventBus,
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from .config import ChatConfig
|
||||
@@ -43,7 +38,6 @@ from .response_model import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
_notification_bus = AsyncRedisNotificationEventBus()
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_sessions: dict[str, asyncio.Task] = {}
|
||||
@@ -751,29 +745,6 @@ async def mark_session_completed(
|
||||
|
||||
# Clean up local session reference if exists
|
||||
_local_sessions.pop(session_id, None)
|
||||
|
||||
# Publish copilot completion notification via WebSocket
|
||||
if meta:
|
||||
parsed = _parse_session_meta(meta, session_id)
|
||||
if parsed.user_id:
|
||||
try:
|
||||
await _notification_bus.publish(
|
||||
NotificationEvent(
|
||||
user_id=parsed.user_id,
|
||||
payload=CopilotCompletionPayload(
|
||||
type="copilot_completion",
|
||||
event="session_completed",
|
||||
session_id=session_id,
|
||||
status=status,
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to publish copilot completion notification "
|
||||
f"for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ import shutil
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url_host
|
||||
|
||||
@@ -44,6 +43,7 @@ from .models import (
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .workspace_files import get_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -194,7 +194,7 @@ async def _save_browser_state(
|
||||
),
|
||||
}
|
||||
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
await manager.write_file(
|
||||
content=json.dumps(state).encode("utf-8"),
|
||||
filename=_STATE_FILENAME,
|
||||
@@ -218,7 +218,7 @@ async def _restore_browser_state(
|
||||
Returns True on success (or no state to restore), False on failure.
|
||||
"""
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is None:
|
||||
@@ -360,7 +360,7 @@ async def close_browser_session(session_name: str, user_id: str | None = None) -
|
||||
# Delete persisted browser state (cookies, localStorage) from workspace.
|
||||
if user_id:
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_name)
|
||||
manager = await get_manager(user_id, session_name)
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is not None:
|
||||
await manager.delete_file(file_info.id)
|
||||
|
||||
@@ -897,7 +897,7 @@ class TestHasLocalSession:
|
||||
# _save_browser_state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_workspace_manager"
|
||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_manager"
|
||||
|
||||
|
||||
def _make_mock_manager():
|
||||
|
||||
@@ -829,12 +829,8 @@ class AgentFixer:
|
||||
|
||||
For nodes whose block has category "AI", this function ensures that the
|
||||
input_default has a "model" parameter set to one of the allowed models.
|
||||
If missing or set to an unsupported value, it is replaced with the
|
||||
appropriate default.
|
||||
|
||||
Blocks that define their own ``enum`` constraint on the ``model`` field
|
||||
in their inputSchema (e.g. PerplexityBlock) are validated against that
|
||||
enum instead of the generic allowed set.
|
||||
If missing or set to an unsupported value, it is replaced with
|
||||
default_model.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to fix
|
||||
@@ -844,7 +840,7 @@ class AgentFixer:
|
||||
Returns:
|
||||
The fixed agent dictionary
|
||||
"""
|
||||
generic_allowed_models = {"gpt-4o", "claude-opus-4-6"}
|
||||
allowed_models = {"gpt-4o", "claude-opus-4-6"}
|
||||
|
||||
# Create a mapping of block_id to block for quick lookup
|
||||
block_map = {block.get("id"): block for block in blocks}
|
||||
@@ -872,36 +868,20 @@ class AgentFixer:
|
||||
input_default = node.get("input_default", {})
|
||||
current_model = input_default.get("model")
|
||||
|
||||
# Determine allowed models and default from the block's schema.
|
||||
# Blocks with a block-specific enum on the model field (e.g.
|
||||
# PerplexityBlock) use their own enum values; others use the
|
||||
# generic set.
|
||||
model_schema = (
|
||||
block.get("inputSchema", {}).get("properties", {}).get("model", {})
|
||||
)
|
||||
block_model_enum = model_schema.get("enum")
|
||||
|
||||
if block_model_enum:
|
||||
allowed_models = set(block_model_enum)
|
||||
fallback_model = model_schema.get("default", block_model_enum[0])
|
||||
else:
|
||||
allowed_models = generic_allowed_models
|
||||
fallback_model = default_model
|
||||
|
||||
if current_model not in allowed_models:
|
||||
block_name = block.get("name", "Unknown AI Block")
|
||||
if current_model is None:
|
||||
self.add_fix_log(
|
||||
f"Added model parameter '{fallback_model}' to AI "
|
||||
f"Added model parameter '{default_model}' to AI "
|
||||
f"block node {node_id} ({block_name})"
|
||||
)
|
||||
else:
|
||||
self.add_fix_log(
|
||||
f"Replaced unsupported model '{current_model}' "
|
||||
f"with '{fallback_model}' on AI block node "
|
||||
f"with '{default_model}' on AI block node "
|
||||
f"{node_id} ({block_name})"
|
||||
)
|
||||
input_default["model"] = fallback_model
|
||||
input_default["model"] = default_model
|
||||
node["input_default"] = input_default
|
||||
fixed_count += 1
|
||||
|
||||
|
||||
@@ -475,111 +475,6 @@ class TestFixAiModelParameter:
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "claude-opus-4-6"
|
||||
|
||||
def test_block_specific_enum_uses_block_default(self):
|
||||
"""Blocks with their own model enum (e.g. PerplexityBlock) should use
|
||||
the block's allowed models and default, not the generic ones."""
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=block_id,
|
||||
input_default={"model": "gpt-5.2-2025-12-11"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"name": "PerplexityBlock",
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
"default": "perplexity/sonar",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar"
|
||||
|
||||
def test_block_specific_enum_valid_model_unchanged(self):
|
||||
"""A valid block-specific model should not be replaced."""
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=block_id,
|
||||
input_default={"model": "perplexity/sonar-pro"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"name": "PerplexityBlock",
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
"default": "perplexity/sonar",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar-pro"
|
||||
|
||||
def test_block_specific_enum_missing_model_gets_block_default(self):
|
||||
"""Missing model on a block with enum should use the block's default."""
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(node_id="n1", block_id=block_id, input_default={})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"name": "PerplexityBlock",
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
"default": "perplexity/sonar",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar"
|
||||
|
||||
|
||||
class TestFixAgentExecutorBlocks:
|
||||
"""Tests for fix_agent_executor_blocks."""
|
||||
|
||||
@@ -935,5 +935,5 @@ class AgentValidator:
|
||||
for i, error in enumerate(self.errors, 1):
|
||||
error_message += f"{i}. {error}\n"
|
||||
|
||||
logger.warning(f"Agent validation failed: {error_message}")
|
||||
logger.error(f"Agent validation failed: {error_message}")
|
||||
return False, error_message
|
||||
|
||||
@@ -21,11 +21,9 @@ Lifecycle
|
||||
Cost control
|
||||
------------
|
||||
Sandboxes are created with a configurable ``on_timeout`` lifecycle action
|
||||
(default: ``"pause"``) and ``auto_resume`` (default: ``True``). The explicit
|
||||
per-turn ``pause_sandbox()`` call is the primary mechanism; the lifecycle
|
||||
timeout is a safety net (default: 5 min). ``auto_resume`` ensures that paused
|
||||
sandboxes wake transparently on SDK activity, making the aggressive safety-net
|
||||
timeout safe. Paused sandboxes are free.
|
||||
(default: ``"pause"``). The explicit per-turn ``pause_sandbox()`` call is the
|
||||
primary mechanism; the lifecycle setting is a safety net. Paused sandboxes are
|
||||
free.
|
||||
|
||||
The sandbox_id is stored in Redis. The same key doubles as a creation lock:
|
||||
a ``"creating"`` sentinel value is written with a short TTL while a new sandbox
|
||||
@@ -42,7 +40,6 @@ import logging
|
||||
from typing import Any, Awaitable, Callable, Literal
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.sandbox.sandbox_api import SandboxLifecycle
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
@@ -119,10 +116,9 @@ async def get_or_create_sandbox(
|
||||
removes the need for a separate lock key.
|
||||
|
||||
*timeout* controls how long the e2b sandbox may run continuously before
|
||||
the ``on_timeout`` lifecycle rule fires (default: 5 min).
|
||||
the ``on_timeout`` lifecycle rule fires (default: 3 h).
|
||||
*on_timeout* controls what happens on timeout: ``"pause"`` (default, free)
|
||||
or ``"kill"``. When ``"pause"``, ``auto_resume`` is enabled so paused
|
||||
sandboxes wake transparently on SDK activity.
|
||||
or ``"kill"``.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = _sandbox_key(session_id)
|
||||
@@ -160,15 +156,11 @@ async def get_or_create_sandbox(
|
||||
|
||||
# We hold the slot — create the sandbox.
|
||||
try:
|
||||
lifecycle = SandboxLifecycle(
|
||||
on_timeout=on_timeout,
|
||||
auto_resume=on_timeout == "pause",
|
||||
)
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
lifecycle=lifecycle,
|
||||
lifecycle={"on_timeout": on_timeout},
|
||||
)
|
||||
try:
|
||||
await _set_stored_sandbox_id(session_id, sandbox.sandbox_id)
|
||||
|
||||
@@ -157,17 +157,14 @@ class TestGetOrCreateSandbox:
|
||||
|
||||
assert result is new_sb
|
||||
mock_cls.create.assert_awaited_once()
|
||||
# Verify lifecycle: pause + auto_resume enabled
|
||||
# Verify lifecycle param is set
|
||||
_, kwargs = mock_cls.create.call_args
|
||||
assert kwargs.get("lifecycle") == {
|
||||
"on_timeout": "pause",
|
||||
"auto_resume": True,
|
||||
}
|
||||
assert kwargs.get("lifecycle") == {"on_timeout": "pause"}
|
||||
# sandbox_id should be saved to Redis
|
||||
redis.set.assert_awaited()
|
||||
|
||||
def test_create_with_on_timeout_kill(self):
|
||||
"""on_timeout='kill' disables auto_resume automatically."""
|
||||
"""on_timeout='kill' is passed through to AsyncSandbox.create."""
|
||||
new_sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
with (
|
||||
@@ -182,10 +179,7 @@ class TestGetOrCreateSandbox:
|
||||
)
|
||||
|
||||
_, kwargs = mock_cls.create.call_args
|
||||
assert kwargs.get("lifecycle") == {
|
||||
"on_timeout": "kill",
|
||||
"auto_resume": False,
|
||||
}
|
||||
assert kwargs.get("lifecycle") == {"on_timeout": "kill"}
|
||||
|
||||
def test_create_failure_releases_slot(self):
|
||||
"""If sandbox creation fails, the Redis creation slot is deleted."""
|
||||
|
||||
@@ -13,7 +13,6 @@ from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
from backend.util.type import coerce_inputs_to_schema
|
||||
|
||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||
from .utils import match_credentials_to_requirements
|
||||
@@ -112,9 +111,6 @@ async def execute_block(
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Coerce non-matching data types to the expected input schema.
|
||||
coerce_inputs_to_schema(input_data, block.input_schema)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
"""Tests for execute_block type coercion in helpers.py.
|
||||
|
||||
Verifies that execute_block() coerces string input values to match the block's
|
||||
expected input types, mirroring the executor's validate_exec() logic.
|
||||
This is critical for @@agptfile: expansion, where file content is always a string
|
||||
but the block may expect structured types (e.g. list[list[str]]).
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
from backend.copilot.tools.models import BlockOutputResponse
|
||||
|
||||
|
||||
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
||||
"""Create a mock input_schema with model_fields matching the given annotations."""
|
||||
schema = MagicMock()
|
||||
# coerce_inputs_to_schema uses model_fields (Pydantic v2 API)
|
||||
model_fields = {}
|
||||
for name, ann in annotations.items():
|
||||
field = MagicMock()
|
||||
field.annotation = ann
|
||||
model_fields[name] = field
|
||||
schema.model_fields = model_fields
|
||||
return schema
|
||||
|
||||
|
||||
def _make_block(
|
||||
block_id: str,
|
||||
name: str,
|
||||
annotations: dict[str, Any],
|
||||
outputs: dict[str, list[Any]] | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock block with typed annotations and a simple execute method."""
|
||||
block = MagicMock()
|
||||
block.id = block_id
|
||||
block.name = name
|
||||
block.input_schema = _make_block_schema(annotations)
|
||||
|
||||
captured_inputs: dict[str, Any] = {}
|
||||
|
||||
async def mock_execute(input_data: dict, **_kwargs: Any):
|
||||
captured_inputs.update(input_data)
|
||||
for output_name, values in (outputs or {"result": ["ok"]}).items():
|
||||
for v in values:
|
||||
yield output_name, v
|
||||
|
||||
block.execute = mock_execute
|
||||
block._captured_inputs = captured_inputs
|
||||
return block
|
||||
|
||||
|
||||
_TEST_SESSION_ID = "test-session-coerce"
|
||||
_TEST_USER_ID = "test-user-coerce"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_nested_list():
|
||||
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
|
||||
block = _make_block(
|
||||
"sheets-write",
|
||||
"Google Sheets Write",
|
||||
{"values": list[list[str]], "spreadsheet_id": str},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="sheets-write",
|
||||
input_data={
|
||||
"values": '[["Name","Score"],["Alice","90"],["Bob","85"]]',
|
||||
"spreadsheet_id": "abc123",
|
||||
},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
# Verify the input was coerced from string to list[list[str]]
|
||||
assert block._captured_inputs["values"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
["Bob", "85"],
|
||||
]
|
||||
assert isinstance(block._captured_inputs["values"], list)
|
||||
assert isinstance(block._captured_inputs["values"][0], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_list():
|
||||
"""JSON string → list[str]."""
|
||||
block = _make_block(
|
||||
"list-block",
|
||||
"List Block",
|
||||
{"items": list[str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="list-block",
|
||||
input_data={"items": '["a","b","c"]'},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-2",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["items"] == ["a", "b", "c"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_dict():
|
||||
"""JSON string → dict[str, str]."""
|
||||
block = _make_block(
|
||||
"dict-block",
|
||||
"Dict Block",
|
||||
{"config": dict[str, str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="dict-block",
|
||||
input_data={"config": '{"key": "value", "foo": "bar"}'},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-3",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["config"] == {"key": "value", "foo": "bar"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_coercion_when_type_matches():
|
||||
"""Already-correct types pass through without coercion."""
|
||||
block = _make_block(
|
||||
"pass-through",
|
||||
"Pass Through",
|
||||
{"values": list[list[str]], "name": str},
|
||||
)
|
||||
|
||||
original_values = [["a", "b"], ["c", "d"]]
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="pass-through",
|
||||
input_data={"values": original_values, "name": "test"},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-4",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["values"] == original_values
|
||||
assert block._captured_inputs["name"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_string_to_int():
|
||||
"""String number → int."""
|
||||
block = _make_block(
|
||||
"int-block",
|
||||
"Int Block",
|
||||
{"count": int},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="int-block",
|
||||
input_data={"count": "42"},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-5",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["count"] == 42
|
||||
assert isinstance(block._captured_inputs["count"], int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_skips_none_values():
|
||||
"""None values are not coerced (they may be optional fields)."""
|
||||
block = _make_block(
|
||||
"optional-block",
|
||||
"Optional Block",
|
||||
{"data": list[str], "label": str},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="optional-block",
|
||||
input_data={"label": "test"},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-6",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
# 'data' was not provided, so it should not appear in captured inputs
|
||||
assert "data" not in block._captured_inputs
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_union_type_preserves_valid_member():
|
||||
"""Union-typed fields should not be coerced when the value matches a member."""
|
||||
block = _make_block(
|
||||
"union-block",
|
||||
"Union Block",
|
||||
{"content": str | list[str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="union-block",
|
||||
input_data={"content": ["a", "b"]},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-7",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
# list[str] should NOT be stringified to '["a", "b"]'
|
||||
assert block._captured_inputs["content"] == ["a", "b"]
|
||||
assert isinstance(block._captured_inputs["content"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_inner_elements_of_generic():
|
||||
"""Inner elements of generic containers are recursively coerced."""
|
||||
block = _make_block(
|
||||
"inner-coerce",
|
||||
"Inner Coerce",
|
||||
{"values": list[str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="inner-coerce",
|
||||
# Inner elements are ints, but target is list[str]
|
||||
input_data={"values": [1, 2, 3]},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-8",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
# Inner elements should be coerced from int to str
|
||||
assert block._captured_inputs["values"] == ["1", "2", "3"]
|
||||
assert all(isinstance(v, str) for v in block._captured_inputs["values"])
|
||||
@@ -12,7 +12,6 @@ from backend.copilot.constants import (
|
||||
COPILOT_SESSION_PREFIX,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
|
||||
from backend.data.db_accessors import review_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
@@ -198,29 +197,6 @@ class RunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Expand @@agptfile: refs in input_data with the block's input
|
||||
# schema. The generic _truncating wrapper skips opaque object
|
||||
# properties (input_data has no declared inner properties in the
|
||||
# tool schema), so file ref tokens are still intact here.
|
||||
# Using the block's schema lets us return raw text for string-typed
|
||||
# fields and parsed structures for list/dict-typed fields.
|
||||
if input_data:
|
||||
try:
|
||||
input_data = await expand_file_refs_in_args(
|
||||
input_data,
|
||||
user_id,
|
||||
session,
|
||||
input_schema=input_schema,
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Failed to resolve file reference: {exc}. "
|
||||
"Ensure the file exists before referencing it."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
@@ -34,11 +34,6 @@ logger = logging.getLogger(__name__)
|
||||
_AUTH_STATUS_CODES = {401, 403}
|
||||
|
||||
|
||||
def _service_name(host: str) -> str:
|
||||
"""Strip the 'mcp.' prefix from an MCP hostname: 'mcp.sentry.dev' → 'sentry.dev'"""
|
||||
return host[4:] if host.startswith("mcp.") else host
|
||||
|
||||
|
||||
class RunMCPToolTool(BaseTool):
|
||||
"""
|
||||
Tool for discovering and executing tools on any MCP server.
|
||||
@@ -184,12 +179,10 @@ class RunMCPToolTool(BaseTool):
|
||||
if e.status_code in _AUTH_STATUS_CODES and not creds:
|
||||
# Server requires auth and user has no stored credentials
|
||||
return self._build_setup_requirements(server_url, session_id)
|
||||
host = server_host(server_url)
|
||||
logger.warning("MCP HTTP error for %s: status=%s", host, e.status_code)
|
||||
logger.warning("MCP HTTP error for %s: %s", server_host(server_url), e)
|
||||
return ErrorResponse(
|
||||
message=(f"MCP request to {host} failed with HTTP {e.status_code}."),
|
||||
message=f"MCP server returned HTTP {e.status_code}: {e}",
|
||||
session_id=session_id,
|
||||
error=f"HTTP {e.status_code}: {str(e)[:300]}",
|
||||
)
|
||||
|
||||
except MCPClientError as e:
|
||||
@@ -310,8 +303,8 @@ class RunMCPToolTool(BaseTool):
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unable to connect to {_service_name(server_host(server_url))} "
|
||||
"— no credentials configured."
|
||||
f"The MCP server at {server_host(server_url)} requires authentication, "
|
||||
"but no credential configuration was found."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -319,13 +312,15 @@ class RunMCPToolTool(BaseTool):
|
||||
missing_creds_list = list(missing_creds_dict.values())
|
||||
|
||||
host = server_host(server_url)
|
||||
service = _service_name(host)
|
||||
return SetupRequirementsResponse(
|
||||
message=(f"To continue, sign in to {service} and approve access."),
|
||||
message=(
|
||||
f"The MCP server at {host} requires authentication. "
|
||||
"Please connect your credentials to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=server_url,
|
||||
agent_name=service,
|
||||
agent_name=f"MCP: {host}",
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
|
||||
@@ -580,49 +580,6 @@ async def test_auth_error_with_existing_creds_returns_error():
|
||||
assert "403" in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_http_error_returns_clean_message_with_collapsible_detail():
|
||||
"""Non-auth HTTP errors return a clean message with raw detail in the `error` field."""
|
||||
from backend.util.request import HTTPClientError
|
||||
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError(
|
||||
"<!doctype html><html><body>Not Found</body></html>",
|
||||
status_code=404,
|
||||
)
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "404" in response.message
|
||||
# Raw HTML body must NOT leak into the user-facing message
|
||||
assert "<!doctype" not in response.message
|
||||
# Raw detail (including original body) goes in the collapsible `error` field
|
||||
assert response.error is not None
|
||||
assert "404" in response.error
|
||||
assert "<!doctype" in response.error.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_mcp_client_error_returns_error_response():
|
||||
"""MCPClientError (protocol-level) maps to a clean ErrorResponse."""
|
||||
@@ -799,4 +756,4 @@ async def test_build_setup_requirements_returns_setup_response():
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_id == _SERVER_URL
|
||||
assert "sign in" in result.message.lower()
|
||||
assert "authentication" in result.message.lower()
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -11,13 +10,11 @@ from pydantic import BaseModel
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
get_workspace_manager,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
@@ -27,10 +24,6 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sentinel file_id used when a tool-result file is read directly from the local
|
||||
# host filesystem (rather than from workspace storage).
|
||||
_LOCAL_TOOL_RESULT_FILE_ID = "local"
|
||||
|
||||
|
||||
async def _resolve_write_content(
|
||||
content_text: str | None,
|
||||
@@ -225,6 +218,12 @@ def _is_text_mime(mime_type: str) -> bool:
|
||||
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped WorkspaceManager."""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
async def _resolve_file(
|
||||
manager: WorkspaceManager,
|
||||
file_id: str | None,
|
||||
@@ -282,93 +281,6 @@ class WorkspaceFileContentResponse(ToolResponseBase):
|
||||
content_base64: str
|
||||
|
||||
|
||||
_MAX_LOCAL_TOOL_RESULT_BYTES = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
|
||||
def _read_local_tool_result(
|
||||
path: str,
|
||||
char_offset: int,
|
||||
char_length: Optional[int],
|
||||
session_id: str,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Read an SDK tool-result file from local disk.
|
||||
|
||||
This is a fallback for when the model mistakenly calls
|
||||
``read_workspace_file`` with an SDK tool-result path that only exists on
|
||||
the host filesystem, not in cloud workspace storage.
|
||||
|
||||
Defence-in-depth: validates *path* via :func:`is_allowed_local_path`
|
||||
regardless of what the caller has already checked.
|
||||
"""
|
||||
# TOCTOU: path validated then opened separately. Acceptable because
|
||||
# the tool-results directory is server-controlled, not user-writable.
|
||||
expanded = os.path.realpath(os.path.expanduser(path))
|
||||
# Defence-in-depth: re-check with resolved path (caller checked raw path).
|
||||
if not is_allowed_local_path(expanded, sdk_cwd or get_sdk_cwd()):
|
||||
return ErrorResponse(
|
||||
message=f"Path not allowed: {os.path.basename(path)}", session_id=session_id
|
||||
)
|
||||
try:
|
||||
# The 10 MB cap (_MAX_LOCAL_TOOL_RESULT_BYTES) bounds memory usage.
|
||||
# Pre-read size check prevents loading files far above the cap;
|
||||
# the remaining TOCTOU gap is acceptable for server-controlled paths.
|
||||
file_size = os.path.getsize(expanded)
|
||||
if file_size > _MAX_LOCAL_TOOL_RESULT_BYTES:
|
||||
return ErrorResponse(
|
||||
message=(f"File too large: {os.path.basename(path)}"),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Detect binary files: try strict UTF-8 first, fall back to
|
||||
# base64-encoding the raw bytes for binary content.
|
||||
with open(expanded, "rb") as fh:
|
||||
raw = fh.read()
|
||||
try:
|
||||
text_content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# Binary file — return raw base64, ignore char_offset/char_length
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
|
||||
name=os.path.basename(path),
|
||||
path=path,
|
||||
mime_type=mimetypes.guess_type(path)[0] or "application/octet-stream",
|
||||
content_base64=base64.b64encode(raw).decode("ascii"),
|
||||
message=(
|
||||
f"Read {file_size:,} bytes (binary) from local tool-result "
|
||||
f"{os.path.basename(path)}"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
end = (
|
||||
char_offset + char_length if char_length is not None else len(text_content)
|
||||
)
|
||||
slice_text = text_content[char_offset:end]
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {os.path.basename(path)}", session_id=session_id
|
||||
)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Error reading file: {type(exc).__name__}", session_id=session_id
|
||||
)
|
||||
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
|
||||
name=os.path.basename(path),
|
||||
path=path,
|
||||
mime_type=mimetypes.guess_type(path)[0] or "text/plain",
|
||||
content_base64=base64.b64encode(slice_text.encode("utf-8")).decode("ascii"),
|
||||
message=(
|
||||
f"Read chars {char_offset}\u2013{char_offset + len(slice_text)} "
|
||||
f"of {len(text_content):,} chars from local tool-result "
|
||||
f"{os.path.basename(path)}"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||
|
||||
@@ -474,7 +386,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
@@ -624,17 +536,9 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
# Fallback: if the path is an SDK tool-result on local disk,
|
||||
# read it directly instead of failing. The model sometimes
|
||||
# calls read_workspace_file for these paths by mistake.
|
||||
sdk_cwd = get_sdk_cwd()
|
||||
if path and is_allowed_local_path(path, sdk_cwd):
|
||||
return _read_local_tool_result(
|
||||
path, char_offset, char_length, session_id, sdk_cwd=sdk_cwd
|
||||
)
|
||||
return resolved
|
||||
target_file_id, file_info = resolved
|
||||
|
||||
@@ -868,7 +772,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
try:
|
||||
await scan_content_safe(content, filename=filename)
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
@@ -995,7 +899,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
|
||||
@@ -2,25 +2,18 @@
|
||||
|
||||
import base64
|
||||
import os
|
||||
import shutil
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import SDK_PROJECTS_DIR, _current_project_dir
|
||||
from backend.copilot.tools._test_data import make_session, setup_test_data
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
from backend.copilot.tools.workspace_files import (
|
||||
_MAX_LOCAL_TOOL_RESULT_BYTES,
|
||||
DeleteWorkspaceFileTool,
|
||||
ListWorkspaceFilesTool,
|
||||
ReadWorkspaceFileTool,
|
||||
WorkspaceDeleteResponse,
|
||||
WorkspaceFileContentResponse,
|
||||
WorkspaceFileListResponse,
|
||||
WorkspaceWriteResponse,
|
||||
WriteWorkspaceFileTool,
|
||||
_read_local_tool_result,
|
||||
_resolve_write_content,
|
||||
_validate_ephemeral_path,
|
||||
)
|
||||
@@ -332,294 +325,3 @@ async def test_write_workspace_file_source_path(setup_test_data):
|
||||
await delete_tool._execute(
|
||||
user_id=user.id, session=session, file_id=write_resp.file_id
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_local_tool_result — local disk fallback for SDK tool-result files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
|
||||
|
||||
class TestReadLocalToolResult:
|
||||
"""Tests for _read_local_tool_result (local disk fallback)."""
|
||||
|
||||
def _make_tool_result(self, encoded: str, filename: str, content: bytes) -> str:
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
|
||||
os.makedirs(tool_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_dir, filename)
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(content)
|
||||
return filepath
|
||||
|
||||
def _cleanup(self, encoded: str) -> None:
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
|
||||
def test_read_text_file(self):
|
||||
"""Read a UTF-8 text tool-result file."""
|
||||
encoded = "-tmp-copilot-local-read-text"
|
||||
path = self._make_tool_result(encoded, "output.txt", b"hello world")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "hello world"
|
||||
assert "text/plain" in result.mime_type
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_read_text_with_offset(self):
|
||||
"""Read a slice of a text file using char_offset and char_length."""
|
||||
encoded = "-tmp-copilot-local-read-offset"
|
||||
path = self._make_tool_result(encoded, "data.txt", b"ABCDEFGHIJ")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 3, 4, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "DEFG"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_read_binary_file(self):
|
||||
"""Binary files are returned as raw base64."""
|
||||
encoded = "-tmp-copilot-local-read-binary"
|
||||
binary_data = bytes(range(256))
|
||||
path = self._make_tool_result(encoded, "image.png", binary_data)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64)
|
||||
assert decoded == binary_data
|
||||
assert "binary" in result.message
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_disallowed_path_rejected(self):
|
||||
"""Paths not under allowed directories are rejected."""
|
||||
result = _read_local_tool_result("/etc/passwd", 0, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "not allowed" in result.message.lower()
|
||||
|
||||
def test_file_not_found(self):
|
||||
"""Missing files return an error."""
|
||||
encoded = "-tmp-copilot-local-read-missing"
|
||||
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
|
||||
os.makedirs(tool_dir, exist_ok=True)
|
||||
path = os.path.join(tool_dir, "nope.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "not found" in result.message.lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_file_too_large(self, monkeypatch):
|
||||
"""Files exceeding the size limit are rejected."""
|
||||
encoded = "-tmp-copilot-local-read-large"
|
||||
# Create a small file but fake os.path.getsize to return a huge value
|
||||
path = self._make_tool_result(encoded, "big.txt", b"small")
|
||||
token = _current_project_dir.set(encoded)
|
||||
monkeypatch.setattr(
|
||||
"os.path.getsize", lambda _: _MAX_LOCAL_TOOL_RESULT_BYTES + 1
|
||||
)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "too large" in result.message.lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_offset_beyond_file_length(self):
|
||||
"""Offset past end-of-file returns empty content."""
|
||||
encoded = "-tmp-copilot-local-read-past-eof"
|
||||
path = self._make_tool_result(encoded, "short.txt", b"abc")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 999, 10, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == ""
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_zero_length_read(self):
|
||||
"""Requesting zero characters returns empty content."""
|
||||
encoded = "-tmp-copilot-local-read-zero-len"
|
||||
path = self._make_tool_result(encoded, "data.txt", b"ABCDEF")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 2, 0, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == ""
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_mime_type_from_json_extension(self):
|
||||
"""JSON files get application/json MIME type, not hardcoded text/plain."""
|
||||
encoded = "-tmp-copilot-local-read-json"
|
||||
path = self._make_tool_result(encoded, "result.json", b'{"key": "value"}')
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
assert result.mime_type == "application/json"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_mime_type_from_png_extension(self):
|
||||
"""Binary .png files get image/png MIME type via mimetypes."""
|
||||
encoded = "-tmp-copilot-local-read-png-mime"
|
||||
binary_data = bytes(range(256))
|
||||
path = self._make_tool_result(encoded, "chart.png", binary_data)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
assert result.mime_type == "image/png"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_explicit_sdk_cwd_parameter(self):
|
||||
"""The sdk_cwd parameter overrides get_sdk_cwd() for path validation."""
|
||||
encoded = "-tmp-copilot-local-read-sdkcwd"
|
||||
path = self._make_tool_result(encoded, "out.txt", b"content")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
# Pass sdk_cwd explicitly — should still succeed because the path
|
||||
# is under SDK_PROJECTS_DIR which is always allowed.
|
||||
result = _read_local_tool_result(
|
||||
path, 0, None, "s1", sdk_cwd="/tmp/copilot-test"
|
||||
)
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "content"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_offset_with_no_length_reads_to_end(self):
|
||||
"""When char_length is None, read from offset to end of file."""
|
||||
encoded = "-tmp-copilot-local-read-offset-noLen"
|
||||
path = self._make_tool_result(encoded, "data.txt", b"0123456789")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 5, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "56789"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReadWorkspaceFileTool fallback to _read_local_tool_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_read_workspace_file_falls_back_to_local_tool_result(setup_test_data):
|
||||
"""When _resolve_file returns ErrorResponse for an allowed local path,
|
||||
ReadWorkspaceFileTool should fall back to _read_local_tool_result."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
|
||||
# Create a real tool-result file on disk so the fallback can read it.
|
||||
encoded = "-tmp-copilot-fallback-test"
|
||||
conv_uuid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results")
|
||||
os.makedirs(tool_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_dir, "result.txt")
|
||||
with open(filepath, "w") as f:
|
||||
f.write("fallback content")
|
||||
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
# Mock _resolve_file to return an ErrorResponse (simulating "file not
|
||||
# found in workspace") so the fallback branch is exercised.
|
||||
mock_resolve = AsyncMock(
|
||||
return_value=ErrorResponse(
|
||||
message="File not found at path: result.txt",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
)
|
||||
with patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve):
|
||||
read_tool = ReadWorkspaceFileTool()
|
||||
result = await read_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
path=filepath,
|
||||
)
|
||||
|
||||
# Should have fallen back to _read_local_tool_result and succeeded.
|
||||
assert isinstance(result, WorkspaceFileContentResponse), (
|
||||
f"Expected fallback to local read, got {type(result).__name__}: "
|
||||
f"{getattr(result, 'message', '')}"
|
||||
)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "fallback content"
|
||||
mock_resolve.assert_awaited_once()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_read_workspace_file_no_fallback_when_resolve_succeeds(setup_test_data):
|
||||
"""When _resolve_file succeeds, the local-disk fallback must NOT be invoked."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
|
||||
fake_file_id = "fake-file-id-001"
|
||||
fake_content = b"workspace content"
|
||||
|
||||
# Build a minimal file_info stub that the tool's happy-path needs.
|
||||
class _FakeFileInfo:
|
||||
id = fake_file_id
|
||||
name = "result.json"
|
||||
path = "/result.json"
|
||||
mime_type = "text/plain"
|
||||
size_bytes = len(fake_content)
|
||||
|
||||
mock_resolve = AsyncMock(return_value=(fake_file_id, _FakeFileInfo()))
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file_by_id = AsyncMock(return_value=fake_content)
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve),
|
||||
patch(
|
||||
"backend.copilot.tools.workspace_files.get_workspace_manager",
|
||||
AsyncMock(return_value=mock_manager),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.workspace_files._read_local_tool_result"
|
||||
) as patched_local,
|
||||
):
|
||||
read_tool = ReadWorkspaceFileTool()
|
||||
result = await read_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
file_id=fake_file_id,
|
||||
)
|
||||
|
||||
# Fallback must not have been called.
|
||||
patched_local.assert_not_called()
|
||||
# Normal workspace path must have produced a content response.
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
assert base64.b64decode(result.content_base64) == fake_content
|
||||
|
||||
@@ -100,31 +100,19 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_120B: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_20B: 1,
|
||||
LlmModel.GEMINI_2_5_PRO_PREVIEW: 4,
|
||||
LlmModel.GEMINI_2_5_PRO: 4,
|
||||
LlmModel.GEMINI_3_1_PRO_PREVIEW: 5,
|
||||
LlmModel.GEMINI_3_FLASH_PREVIEW: 2,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
|
||||
LlmModel.GEMINI_2_5_FLASH: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH: 1,
|
||||
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
LlmModel.MISTRAL_LARGE_3: 2,
|
||||
LlmModel.MISTRAL_MEDIUM_3_1: 2,
|
||||
LlmModel.MISTRAL_SMALL_3_2: 1,
|
||||
LlmModel.CODESTRAL: 1,
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
|
||||
LlmModel.COHERE_COMMAND_A_03_2025: 3,
|
||||
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: 3,
|
||||
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: 6,
|
||||
LlmModel.COHERE_COMMAND_A_VISION_07_2025: 3,
|
||||
LlmModel.DEEPSEEK_CHAT: 2,
|
||||
LlmModel.DEEPSEEK_R1_0528: 1,
|
||||
LlmModel.PERPLEXITY_SONAR: 1,
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
|
||||
@@ -132,7 +120,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: 1,
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
|
||||
LlmModel.MICROSOFT_PHI_4: 1,
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
|
||||
LlmModel.META_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.META_LLAMA_4_MAVERICK: 1,
|
||||
@@ -140,7 +127,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.GROK_3: 3,
|
||||
LlmModel.GROK_4: 9,
|
||||
LlmModel.GROK_4_FAST: 1,
|
||||
LlmModel.GROK_4_1_FAST: 1,
|
||||
|
||||
@@ -1,750 +0,0 @@
|
||||
import asyncio
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from prisma.errors import UniqueViolationError
|
||||
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.model import User
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.tally import get_business_understanding_input_from_tally, mask_email
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_by_id
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import (
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_settings = Settings()
|
||||
|
||||
_WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
|
||||
|
||||
_tally_seed_tasks: dict[str, asyncio.Task] = {}
|
||||
_TALLY_STALE_SECONDS = 300
|
||||
_MAX_TALLY_ERROR_LENGTH = 200
|
||||
_email_adapter = TypeAdapter(EmailStr)
|
||||
|
||||
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
|
||||
MAX_BULK_INVITE_ROWS = 500
|
||||
|
||||
|
||||
class InvitedUserRecord(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
|
||||
payload = (
|
||||
invited_user.tallyUnderstanding
|
||||
if isinstance(invited_user.tallyUnderstanding, dict)
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
id=invited_user.id,
|
||||
email=invited_user.email,
|
||||
status=invited_user.status,
|
||||
auth_user_id=invited_user.authUserId,
|
||||
name=invited_user.name,
|
||||
tally_understanding=payload,
|
||||
tally_status=invited_user.tallyStatus,
|
||||
tally_computed_at=invited_user.tallyComputedAt,
|
||||
tally_error=invited_user.tallyError,
|
||||
created_at=invited_user.createdAt,
|
||||
updated_at=invited_user.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class BulkInvitedUserRowResult(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserRecord] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResult(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResult]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ParsedInviteRow:
|
||||
row_number: int
|
||||
email: str
|
||||
name: Optional[str]
|
||||
|
||||
|
||||
def normalize_email(email: str) -> str:
|
||||
return email.strip().lower()
|
||||
|
||||
|
||||
def _normalize_name(name: Optional[str]) -> Optional[str]:
|
||||
if name is None:
|
||||
return None
|
||||
normalized = name.strip()
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
|
||||
if preferred_name:
|
||||
return preferred_name
|
||||
local_part = email.split("@", 1)[0].strip()
|
||||
return local_part or "user"
|
||||
|
||||
|
||||
def _sanitize_username_base(email: str) -> str:
|
||||
local_part = email.split("@", 1)[0].lower()
|
||||
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
|
||||
sanitized = sanitized.strip("-")
|
||||
return sanitized[:40] or "user"
|
||||
|
||||
|
||||
async def _generate_unique_profile_username(email: str, tx) -> str:
|
||||
base = _sanitize_username_base(email)
|
||||
|
||||
for _ in range(2):
|
||||
candidate = f"{base}-{uuid4().hex[:6]}"
|
||||
existing = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"username": candidate}
|
||||
)
|
||||
if existing is None:
|
||||
return candidate
|
||||
|
||||
raise RuntimeError(f"Unable to generate unique username for {email}")
|
||||
|
||||
|
||||
async def _ensure_default_profile(
|
||||
user_id: str,
|
||||
email: str,
|
||||
preferred_name: Optional[str],
|
||||
tx,
|
||||
) -> None:
|
||||
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if existing_profile is not None:
|
||||
return
|
||||
|
||||
username = await _generate_unique_profile_username(email, tx)
|
||||
await prisma.models.Profile.prisma(tx).create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=user_id,
|
||||
name=_default_profile_name(email, preferred_name),
|
||||
username=username,
|
||||
description="I'm new here",
|
||||
links=[],
|
||||
avatarUrl="",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _ensure_default_onboarding(user_id: str, tx) -> None:
|
||||
await prisma.models.UserOnboarding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
|
||||
"update": {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _apply_tally_understanding(
|
||||
user_id: str,
|
||||
invited_user: "prisma.models.InvitedUser",
|
||||
tx,
|
||||
) -> None:
|
||||
if not isinstance(invited_user.tallyUnderstanding, dict):
|
||||
return
|
||||
|
||||
try:
|
||||
input_data = BusinessUnderstandingInput.model_validate(
|
||||
invited_user.tallyUnderstanding
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Malformed tallyUnderstanding for invited user %s; skipping",
|
||||
invited_user.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
payload = merge_business_understanding_data({}, input_data)
|
||||
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(payload)},
|
||||
"update": {"data": SafeJson(payload)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def list_invited_users(
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> tuple[list[InvitedUserRecord], int]:
|
||||
total = await prisma.models.InvitedUser.prisma().count()
|
||||
invited_users = await prisma.models.InvitedUser.prisma().find_many(
|
||||
order={"createdAt": "desc"},
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
|
||||
|
||||
|
||||
async def create_invited_user(
|
||||
email: str, name: Optional[str] = None
|
||||
) -> InvitedUserRecord:
|
||||
normalized_email = normalize_email(email)
|
||||
normalized_name = _normalize_name(name)
|
||||
|
||||
existing_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_user is not None:
|
||||
raise PreconditionFailed("An active user with this email already exists")
|
||||
|
||||
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_invited_user is not None:
|
||||
raise PreconditionFailed("An invited user with this email already exists")
|
||||
|
||||
try:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().create(
|
||||
data={
|
||||
"email": normalized_email,
|
||||
"name": normalized_name,
|
||||
"status": prisma.enums.InvitedUserStatus.INVITED,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
raise PreconditionFailed("An invited user with this email already exists")
|
||||
schedule_invited_user_tally_precompute(invited_user.id)
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
|
||||
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
|
||||
raise PreconditionFailed("Claimed invited users cannot be revoked")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
revoked_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
|
||||
)
|
||||
if revoked_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
return InvitedUserRecord.from_db(revoked_user)
|
||||
|
||||
|
||||
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
|
||||
|
||||
refreshed_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": None,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
"tallyComputedAt": None,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
if refreshed_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
schedule_invited_user_tally_precompute(invited_user_id)
|
||||
return InvitedUserRecord.from_db(refreshed_user)
|
||||
|
||||
|
||||
def _decode_bulk_invite_file(content: bytes) -> str:
|
||||
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
|
||||
raise ValueError("Invite file exceeds the maximum size of 1 MB")
|
||||
|
||||
try:
|
||||
return content.decode("utf-8-sig")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError("Invite file must be UTF-8 encoded") from exc
|
||||
|
||||
|
||||
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
|
||||
indexed_rows: list[tuple[int, list[str]]] = []
|
||||
|
||||
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
|
||||
normalized_row = [cell.strip() for cell in row]
|
||||
if any(normalized_row):
|
||||
indexed_rows.append((row_number, normalized_row))
|
||||
|
||||
if not indexed_rows:
|
||||
return []
|
||||
|
||||
header = [cell.lower() for cell in indexed_rows[0][1]]
|
||||
has_header = "email" in header
|
||||
email_index = header.index("email") if has_header else 0
|
||||
name_index: Optional[int] = (
|
||||
header.index("name")
|
||||
if has_header and "name" in header
|
||||
else (1 if not has_header else None)
|
||||
)
|
||||
data_rows = indexed_rows[1:] if has_header else indexed_rows
|
||||
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
for row_number, row in data_rows:
|
||||
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
|
||||
break
|
||||
email = row[email_index].strip() if len(row) > email_index else ""
|
||||
name = (
|
||||
row[name_index].strip()
|
||||
if name_index is not None and len(row) > name_index
|
||||
else ""
|
||||
)
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=email,
|
||||
name=name or None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
|
||||
for row_number, raw_line in enumerate(text.splitlines(), start=1):
|
||||
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
|
||||
break
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=line,
|
||||
name=None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> list[_ParsedInviteRow]:
|
||||
text = _decode_bulk_invite_file(content)
|
||||
file_name = filename.lower() if filename else ""
|
||||
parsed_rows = (
|
||||
_parse_bulk_invite_csv(text)
|
||||
if file_name.endswith(".csv")
|
||||
else _parse_bulk_invite_text(text)
|
||||
)
|
||||
|
||||
if not parsed_rows:
|
||||
raise ValueError("Invite file did not contain any emails")
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
async def bulk_create_invited_users_from_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> BulkInvitedUsersResult:
|
||||
parsed_rows = _parse_bulk_invite_file(filename, content)
|
||||
|
||||
created_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
results: list[BulkInvitedUserRowResult] = []
|
||||
seen_emails: set[str] = set()
|
||||
|
||||
for row in parsed_rows:
|
||||
row_name = _normalize_name(row.name)
|
||||
|
||||
try:
|
||||
validated_email = _email_adapter.validate_python(row.email)
|
||||
except ValidationError:
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=row.email or None,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Invalid email address",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_email = normalize_email(str(validated_email))
|
||||
if normalized_email in seen_emails:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message="Duplicate email in upload file",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
seen_emails.add(normalized_email)
|
||||
|
||||
try:
|
||||
invited_user = await create_invited_user(normalized_email, row_name)
|
||||
except PreconditionFailed as exc:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
masked = mask_email(normalized_email)
|
||||
logger.exception(
|
||||
"Failed to create bulk invite for row %s (%s)",
|
||||
row.row_number,
|
||||
masked,
|
||||
)
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Unexpected error creating invite",
|
||||
)
|
||||
)
|
||||
else:
|
||||
created_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=invited_user,
|
||||
)
|
||||
)
|
||||
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=created_count,
|
||||
skipped_count=skipped_count,
|
||||
error_count=error_count,
|
||||
results=results,
|
||||
)
|
||||
|
||||
|
||||
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
return
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return
|
||||
|
||||
try:
|
||||
r = await get_redis_async()
|
||||
except Exception:
|
||||
r = None
|
||||
|
||||
lock: AsyncClusterLock | None = None
|
||||
|
||||
if r is not None:
|
||||
lock = AsyncClusterLock(
|
||||
redis=r,
|
||||
key=f"tally_seed:{invited_user_id}",
|
||||
owner_id=_WORKER_ID,
|
||||
timeout=_TALLY_STALE_SECONDS,
|
||||
)
|
||||
current_owner = await lock.try_acquire()
|
||||
|
||||
if current_owner is None:
|
||||
logger.warn("Redis unvailable for tally lock - skipping tally enrichement")
|
||||
return
|
||||
elif current_owner != _WORKER_ID:
|
||||
logger.debug(
|
||||
"Tally seed for %s already locked by %s, skipping",
|
||||
invited_user_id,
|
||||
current_owner,
|
||||
)
|
||||
return
|
||||
if (
|
||||
invited_user.tallyStatus == prisma.enums.TallyComputationStatus.RUNNING
|
||||
and invited_user.updatedAt is not None
|
||||
):
|
||||
age = (datetime.now(timezone.utc) - invited_user.updatedAt).total_seconds()
|
||||
if age < _TALLY_STALE_SECONDS:
|
||||
logger.debug(
|
||||
"Tally task for %s still RUNNING (age=%ds), skipping",
|
||||
invited_user_id,
|
||||
int(age),
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"Tally task for %s is stale (age=%ds), re-running",
|
||||
invited_user_id,
|
||||
int(age),
|
||||
)
|
||||
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
input_data = await get_business_understanding_input_from_tally(
|
||||
invited_user.email,
|
||||
require_api_key=True,
|
||||
)
|
||||
payload = (
|
||||
SafeJson(input_data.model_dump(exclude_none=True))
|
||||
if input_data is not None
|
||||
else None
|
||||
)
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": payload,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
|
||||
"tallyComputedAt": datetime.now(timezone.utc),
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Failed to compute Tally understanding for invited user %s",
|
||||
invited_user_id,
|
||||
)
|
||||
sanitized_error = re.sub(
|
||||
r"https?://\S+", "<url>", f"{type(exc).__name__}: {exc}"
|
||||
)[:_MAX_TALLY_ERROR_LENGTH]
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
|
||||
"tallyError": sanitized_error,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
|
||||
existing = _tally_seed_tasks.get(invited_user_id)
|
||||
if existing is not None and not existing.done():
|
||||
logger.debug("Tally task already running for %s, skipping", invited_user_id)
|
||||
return
|
||||
|
||||
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
|
||||
_tally_seed_tasks[invited_user_id] = task
|
||||
|
||||
def _on_done(t: asyncio.Task, _id: str = invited_user_id) -> None:
|
||||
if _tally_seed_tasks.get(_id) is t:
|
||||
del _tally_seed_tasks[_id]
|
||||
|
||||
task.add_done_callback(_on_done)
|
||||
|
||||
|
||||
async def _open_signup_create_user(
|
||||
auth_user_id: str,
|
||||
normalized_email: str,
|
||||
metadata_name: Optional[str],
|
||||
) -> User:
|
||||
"""Create a user without requiring an invite (open signup mode)."""
|
||||
preferred_name = _normalize_name(metadata_name)
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
user = await prisma.models.User.prisma(tx).create(
|
||||
data=prisma.types.UserCreateInput(
|
||||
id=auth_user_id,
|
||||
email=normalized_email,
|
||||
name=preferred_name,
|
||||
)
|
||||
)
|
||||
await _ensure_default_profile(
|
||||
auth_user_id, normalized_email, preferred_name, tx
|
||||
)
|
||||
await _ensure_default_onboarding(auth_user_id, tx)
|
||||
except UniqueViolationError:
|
||||
existing = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if existing is not None:
|
||||
return User.from_db(existing)
|
||||
raise
|
||||
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
# TODO: We need to change this functions logic before going live
|
||||
async def get_or_activate_user(user_data: dict) -> User:
|
||||
auth_user_id = user_data.get("sub")
|
||||
if not auth_user_id:
|
||||
raise NotAuthorizedError("User ID not found in token")
|
||||
|
||||
auth_email = user_data.get("email")
|
||||
if not auth_email:
|
||||
raise NotAuthorizedError("Email not found in token")
|
||||
|
||||
normalized_email = normalize_email(auth_email)
|
||||
user_metadata = user_data.get("user_metadata")
|
||||
metadata_name = (
|
||||
user_metadata.get("name") if isinstance(user_metadata, dict) else None
|
||||
)
|
||||
|
||||
existing_user = None
|
||||
try:
|
||||
existing_user = await get_user_by_id(auth_user_id)
|
||||
except ValueError:
|
||||
existing_user = None
|
||||
except Exception:
|
||||
logger.exception("Error on get user by id during tally enrichment process")
|
||||
raise
|
||||
|
||||
if existing_user is not None:
|
||||
return existing_user
|
||||
|
||||
if not _settings.config.enable_invite_gate or normalized_email.endswith("@agpt.co"):
|
||||
return await _open_signup_create_user(
|
||||
auth_user_id, normalized_email, metadata_name
|
||||
)
|
||||
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotAuthorizedError("Your email is not allowed to access the platform")
|
||||
|
||||
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
current_user = await prisma.models.User.prisma(tx).find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if current_user is not None:
|
||||
return User.from_db(current_user)
|
||||
|
||||
current_invited_user = await prisma.models.InvitedUser.prisma(
|
||||
tx
|
||||
).find_unique(where={"email": normalized_email})
|
||||
if current_invited_user is None:
|
||||
raise NotAuthorizedError(
|
||||
"Your email is not allowed to access the platform"
|
||||
)
|
||||
|
||||
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
if current_invited_user.authUserId not in (None, auth_user_id):
|
||||
raise NotAuthorizedError("Your invitation has already been claimed")
|
||||
|
||||
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
|
||||
await prisma.models.User.prisma(tx).create(
|
||||
data=prisma.types.UserCreateInput(
|
||||
id=auth_user_id,
|
||||
email=normalized_email,
|
||||
name=preferred_name,
|
||||
)
|
||||
)
|
||||
|
||||
await prisma.models.InvitedUser.prisma(tx).update(
|
||||
where={"id": current_invited_user.id},
|
||||
data={
|
||||
"status": prisma.enums.InvitedUserStatus.CLAIMED,
|
||||
"authUserId": auth_user_id,
|
||||
},
|
||||
)
|
||||
|
||||
await _ensure_default_profile(
|
||||
auth_user_id,
|
||||
normalized_email,
|
||||
preferred_name,
|
||||
tx,
|
||||
)
|
||||
await _ensure_default_onboarding(auth_user_id, tx)
|
||||
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
|
||||
except UniqueViolationError:
|
||||
logger.info("Concurrent activation for user %s; re-fetching", auth_user_id)
|
||||
already_created = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if already_created is not None:
|
||||
return User.from_db(already_created)
|
||||
raise RuntimeError(
|
||||
f"UniqueViolationError during activation but user {auth_user_id} not found"
|
||||
)
|
||||
|
||||
get_user_by_id.cache_delete(auth_user_id)
|
||||
get_user_by_email.cache_delete(normalized_email)
|
||||
|
||||
activated_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if activated_user is None:
|
||||
raise RuntimeError(
|
||||
f"Activated user {auth_user_id} was not found after creation"
|
||||
)
|
||||
|
||||
return User.from_db(activated_user)
|
||||
@@ -1,335 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
|
||||
|
||||
from .invited_user import (
|
||||
InvitedUserRecord,
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
get_or_activate_user,
|
||||
retry_invited_user_tally,
|
||||
)
|
||||
|
||||
|
||||
def _invited_user_db_record(
|
||||
*,
|
||||
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
|
||||
tally_understanding: dict | None = None,
|
||||
):
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=status,
|
||||
authUserId=None,
|
||||
name="Invited User",
|
||||
tallyUnderstanding=tally_understanding,
|
||||
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tallyComputedAt=None,
|
||||
tallyError=None,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
)
|
||||
|
||||
|
||||
def _invited_user_record(
|
||||
*,
|
||||
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
|
||||
tally_understanding: dict | None = None,
|
||||
):
|
||||
return InvitedUserRecord.from_db(
|
||||
cast(
|
||||
prisma.models.InvitedUser,
|
||||
_invited_user_db_record(
|
||||
status=status,
|
||||
tally_understanding=tally_understanding,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _user_db_record():
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
id="auth-user-1",
|
||||
email="invited@example.com",
|
||||
emailVerified=True,
|
||||
name="Invited User",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
metadata={},
|
||||
integrations="",
|
||||
stripeCustomerId=None,
|
||||
topUpConfig=None,
|
||||
maxEmailsPerDay=3,
|
||||
notifyOnAgentRun=True,
|
||||
notifyOnZeroBalance=True,
|
||||
notifyOnLowBalance=True,
|
||||
notifyOnBlockExecutionFailed=True,
|
||||
notifyOnContinuousAgentError=True,
|
||||
notifyOnDailySummary=True,
|
||||
notifyOnWeeklySummary=True,
|
||||
notifyOnMonthlySummary=True,
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="not-set",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invited_user_rejects_existing_active_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock()
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
with pytest.raises(PreconditionFailed):
|
||||
await create_invited_user("Invited@example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invited_user_schedules_tally_seed(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
|
||||
schedule = mocker.patch(
|
||||
"backend.data.invited_user.schedule_invited_user_tally_precompute"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
invited_user = await create_invited_user("Invited@example.com", "Invited User")
|
||||
|
||||
assert invited_user.email == "invited@example.com"
|
||||
invited_user_repo.create.assert_awaited_once()
|
||||
schedule.assert_called_once_with("invite-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_invited_user_tally_resets_state_and_schedules(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
|
||||
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
|
||||
schedule = mocker.patch(
|
||||
"backend.data.invited_user.schedule_invited_user_tally_precompute"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
invited_user = await retry_invited_user_tally("invite-1")
|
||||
|
||||
assert invited_user.id == "invite-1"
|
||||
invited_user_repo.update.assert_awaited_once()
|
||||
schedule.assert_called_once_with("invite-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_activate_user_requires_invite(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
|
||||
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
|
||||
mock_get_user_by_id.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_id",
|
||||
mock_get_user_by_id,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user._settings.config.enable_invite_gate",
|
||||
True,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
with pytest.raises(NotAuthorizedError):
|
||||
await get_or_activate_user(
|
||||
{"sub": "auth-user-1", "email": "invited@example.com"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_activate_user_creates_user_from_invite(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
tx = object()
|
||||
invited_user = _invited_user_db_record(
|
||||
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
|
||||
)
|
||||
created_user = _user_db_record()
|
||||
|
||||
outside_user_repo = Mock()
|
||||
# Only called once at post-transaction verification (line 741);
|
||||
# get_user_by_id (line 657) uses prisma.user.find_unique, not this mock.
|
||||
outside_user_repo.find_unique = AsyncMock(return_value=created_user)
|
||||
|
||||
inside_user_repo = Mock()
|
||||
inside_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
inside_user_repo.create = AsyncMock(return_value=created_user)
|
||||
|
||||
outside_invited_repo = Mock()
|
||||
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
|
||||
|
||||
inside_invited_repo = Mock()
|
||||
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
|
||||
inside_invited_repo.update = AsyncMock(return_value=invited_user)
|
||||
|
||||
def user_prisma(client=None):
|
||||
return inside_user_repo if client is tx else outside_user_repo
|
||||
|
||||
def invited_user_prisma(client=None):
|
||||
return inside_invited_repo if client is tx else outside_invited_repo
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_transaction():
|
||||
yield tx
|
||||
|
||||
# Mock get_user_by_id since it uses prisma.user.find_unique (global client),
|
||||
# not prisma.models.User.prisma().find_unique which we mock above.
|
||||
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
|
||||
mock_get_user_by_id.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_id",
|
||||
mock_get_user_by_id,
|
||||
)
|
||||
mock_get_user_by_email = AsyncMock()
|
||||
mock_get_user_by_email.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_email",
|
||||
mock_get_user_by_email,
|
||||
)
|
||||
ensure_profile = mocker.patch(
|
||||
"backend.data.invited_user._ensure_default_profile",
|
||||
AsyncMock(),
|
||||
)
|
||||
ensure_onboarding = mocker.patch(
|
||||
"backend.data.invited_user._ensure_default_onboarding",
|
||||
AsyncMock(),
|
||||
)
|
||||
apply_tally = mocker.patch(
|
||||
"backend.data.invited_user._apply_tally_understanding",
|
||||
AsyncMock(),
|
||||
)
|
||||
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
side_effect=invited_user_prisma,
|
||||
)
|
||||
|
||||
user = await get_or_activate_user(
|
||||
{
|
||||
"sub": "auth-user-1",
|
||||
"email": "Invited@example.com",
|
||||
"user_metadata": {"name": "Invited User"},
|
||||
}
|
||||
)
|
||||
|
||||
assert user.id == "auth-user-1"
|
||||
inside_user_repo.create.assert_awaited_once()
|
||||
inside_invited_repo.update.assert_awaited_once()
|
||||
ensure_profile.assert_awaited_once()
|
||||
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
|
||||
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_invited_users_from_text_file(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
create_invited = mocker.patch(
|
||||
"backend.data.invited_user.create_invited_user",
|
||||
AsyncMock(
|
||||
side_effect=[
|
||||
_invited_user_record(),
|
||||
_invited_user_record(),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = await bulk_create_invited_users_from_file(
|
||||
"invites.txt",
|
||||
b"Invited@example.com\nsecond@example.com\n",
|
||||
)
|
||||
|
||||
assert result.created_count == 2
|
||||
assert result.skipped_count == 0
|
||||
assert result.error_count == 0
|
||||
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
|
||||
assert create_invited.await_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
create_invited = mocker.patch(
|
||||
"backend.data.invited_user.create_invited_user",
|
||||
AsyncMock(
|
||||
side_effect=[
|
||||
_invited_user_record(),
|
||||
PreconditionFailed("An invited user with this email already exists"),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = await bulk_create_invited_users_from_file(
|
||||
"invites.csv",
|
||||
(
|
||||
"email,name\n"
|
||||
"valid@example.com,Valid User\n"
|
||||
"not-an-email,Bad Row\n"
|
||||
"valid@example.com,Duplicate In File\n"
|
||||
"existing@example.com,Existing User\n"
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
assert result.created_count == 1
|
||||
assert result.skipped_count == 2
|
||||
assert result.error_count == 1
|
||||
assert [row.status for row in result.results] == [
|
||||
"CREATED",
|
||||
"ERROR",
|
||||
"SKIPPED",
|
||||
"SKIPPED",
|
||||
]
|
||||
assert create_invited.await_count == 2
|
||||
@@ -8,8 +8,6 @@ from backend.api.model import NotificationPayload
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.util.settings import Settings
|
||||
|
||||
_settings = Settings()
|
||||
|
||||
|
||||
class NotificationEvent(BaseModel):
|
||||
"""Generic notification event destined for websocket delivery."""
|
||||
@@ -28,7 +26,7 @@ class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return _settings.config.notification_event_bus_name
|
||||
return Settings().config.notification_event_bus_name
|
||||
|
||||
async def publish(self, event: NotificationEvent) -> None:
|
||||
await self.publish_event(event, event.user_id)
|
||||
|
||||
@@ -41,7 +41,7 @@ _MAX_PAGES = 100
|
||||
_LLM_TIMEOUT = 30
|
||||
|
||||
|
||||
def mask_email(email: str) -> str:
|
||||
def _mask_email(email: str) -> str:
|
||||
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
|
||||
try:
|
||||
local, domain = email.rsplit("@", 1)
|
||||
@@ -196,7 +196,8 @@ async def _refresh_cache(form_id: str) -> tuple[dict, list]:
|
||||
|
||||
Returns (email_index, questions).
|
||||
"""
|
||||
client = _make_tally_client(_settings.secrets.tally_api_key)
|
||||
settings = Settings()
|
||||
client = _make_tally_client(settings.secrets.tally_api_key)
|
||||
|
||||
redis = await get_redis_async()
|
||||
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
|
||||
@@ -331,9 +332,6 @@ Fields:
|
||||
- current_software (list of strings): software/tools currently used
|
||||
- existing_automation (list of strings): existing automations
|
||||
- additional_notes (string): any additional context
|
||||
- suggested_prompts (list of 5 strings): short action prompts (each under 20 words) that would help \
|
||||
this person get started with automating their work. Should be specific to their industry, role, and \
|
||||
pain points; actionable and conversational in tone; focused on automation opportunities.
|
||||
|
||||
Form data:
|
||||
"""
|
||||
@@ -341,21 +339,21 @@ Form data:
|
||||
_EXTRACTION_SUFFIX = "\n\nReturn ONLY valid JSON."
|
||||
|
||||
|
||||
async def extract_business_understanding_from_tally(
|
||||
async def extract_business_understanding(
|
||||
formatted_text: str,
|
||||
) -> BusinessUnderstandingInput:
|
||||
"""
|
||||
Use an LLM to extract structured business understanding from form text.
|
||||
"""Use an LLM to extract structured business understanding from form text.
|
||||
|
||||
Raises on timeout or unparseable response so the caller can handle it.
|
||||
"""
|
||||
api_key = _settings.secrets.open_router_api_key
|
||||
settings = Settings()
|
||||
api_key = settings.secrets.open_router_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=OPENROUTER_BASE_URL)
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
client.chat.completions.create(
|
||||
model=_settings.config.tally_extraction_llm_model,
|
||||
model="openai/gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
@@ -380,57 +378,9 @@ async def extract_business_understanding_from_tally(
|
||||
|
||||
# Filter out null values before constructing
|
||||
cleaned = {k: v for k, v in data.items() if v is not None}
|
||||
|
||||
# Validate suggested_prompts: filter >20 words, keep top 3
|
||||
raw_prompts = cleaned.get("suggested_prompts", [])
|
||||
if isinstance(raw_prompts, list):
|
||||
valid = [
|
||||
p.strip()
|
||||
for p in raw_prompts
|
||||
if isinstance(p, str) and len(p.strip().split()) <= 20
|
||||
]
|
||||
# This will keep up to 3 suggestions
|
||||
short_prompts = valid[:3] if valid else None
|
||||
if short_prompts:
|
||||
cleaned["suggested_prompts"] = short_prompts
|
||||
else:
|
||||
# We dont want to add a None value suggested_prompts field
|
||||
cleaned.pop("suggested_prompts", None)
|
||||
else:
|
||||
# suggested_prompts must be a list - removing it as its not here
|
||||
cleaned.pop("suggested_prompts", None)
|
||||
|
||||
return BusinessUnderstandingInput(**cleaned)
|
||||
|
||||
|
||||
async def get_business_understanding_input_from_tally(
|
||||
email: str,
|
||||
*,
|
||||
require_api_key: bool = False,
|
||||
) -> Optional[BusinessUnderstandingInput]:
|
||||
if not _settings.secrets.tally_api_key:
|
||||
if require_api_key:
|
||||
raise RuntimeError("Tally API key is not configured")
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
return None
|
||||
|
||||
masked = mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return None
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return None
|
||||
|
||||
return await extract_business_understanding_from_tally(formatted)
|
||||
|
||||
|
||||
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
"""Main orchestrator: check Tally for a matching submission and populate understanding.
|
||||
|
||||
@@ -445,10 +395,30 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
)
|
||||
return
|
||||
|
||||
understanding_input = await get_business_understanding_input_from_tally(email)
|
||||
if understanding_input is None:
|
||||
# Check API key is configured
|
||||
settings = Settings()
|
||||
if not settings.secrets.tally_api_key:
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
return
|
||||
|
||||
# Look up submission by email
|
||||
masked = _mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
# Format and extract
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return
|
||||
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
|
||||
# Upsert into database
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
logger.info(f"Tally: successfully populated understanding for user {user_id}")
|
||||
|
||||
@@ -12,11 +12,11 @@ from backend.data.tally import (
|
||||
_build_email_index,
|
||||
_format_answer,
|
||||
_make_tally_client,
|
||||
_mask_email,
|
||||
_refresh_cache,
|
||||
extract_business_understanding_from_tally,
|
||||
extract_business_understanding,
|
||||
find_submission_by_email,
|
||||
format_submission_for_llm,
|
||||
mask_email,
|
||||
populate_understanding_from_tally,
|
||||
)
|
||||
|
||||
@@ -248,7 +248,7 @@ async def test_populate_understanding_skips_no_api_key():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
@@ -284,7 +284,6 @@ async def test_populate_understanding_full_flow():
|
||||
],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
mock_input.suggested_prompts = ["Prompt 1", "Prompt 2", "Prompt 3"]
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -292,14 +291,14 @@ async def test_populate_understanding_full_flow():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding_from_tally",
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_input,
|
||||
) as mock_extract,
|
||||
@@ -332,14 +331,14 @@ async def test_populate_understanding_handles_llm_timeout():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding_from_tally",
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError(),
|
||||
),
|
||||
@@ -357,13 +356,13 @@ async def test_populate_understanding_handles_llm_timeout():
|
||||
|
||||
|
||||
def test_mask_email():
|
||||
assert mask_email("alice@example.com") == "a***e@example.com"
|
||||
assert mask_email("ab@example.com") == "a***@example.com"
|
||||
assert mask_email("a@example.com") == "a***@example.com"
|
||||
assert _mask_email("alice@example.com") == "a***e@example.com"
|
||||
assert _mask_email("ab@example.com") == "a***@example.com"
|
||||
assert _mask_email("a@example.com") == "a***@example.com"
|
||||
|
||||
|
||||
def test_mask_email_invalid():
|
||||
assert mask_email("no-at-sign") == "***"
|
||||
assert _mask_email("no-at-sign") == "***"
|
||||
|
||||
|
||||
# ── Prompt construction (curly-brace safety) ─────────────────────────────────
|
||||
@@ -394,11 +393,11 @@ def test_extraction_prompt_no_format_placeholders():
|
||||
assert single_braces == [], f"Found format placeholders: {single_braces}"
|
||||
|
||||
|
||||
# ── extract_business_understanding_from_tally ────────────────────────────────────────────
|
||||
# ── extract_business_understanding ────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_success():
|
||||
async def test_extract_business_understanding_success():
|
||||
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
@@ -407,13 +406,6 @@ async def test_extract_business_understanding_from_tally_success():
|
||||
"business_name": "Acme Corp",
|
||||
"industry": "Technology",
|
||||
"pain_points": ["manual reporting"],
|
||||
"suggested_prompts": [
|
||||
"Automate weekly reports",
|
||||
"Set up invoice processing",
|
||||
"Create a customer onboarding flow",
|
||||
"Track project deadlines automatically",
|
||||
"Send follow-up emails after meetings",
|
||||
],
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
@@ -423,56 +415,16 @@ async def test_extract_business_understanding_from_tally_success():
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.business_name == "Acme Corp"
|
||||
assert result.industry == "Technology"
|
||||
assert result.pain_points == ["manual reporting"]
|
||||
# suggested_prompts validated and sliced to top 3
|
||||
assert result.suggested_prompts == [
|
||||
"Automate weekly reports",
|
||||
"Set up invoice processing",
|
||||
"Create a customer onboarding flow",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_filters_long_prompts():
|
||||
"""Prompts exceeding 20 words are excluded and only top 3 are kept."""
|
||||
long_prompt = " ".join(["word"] * 21)
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"user_name": "Alice",
|
||||
"suggested_prompts": [
|
||||
long_prompt,
|
||||
"Short prompt one",
|
||||
long_prompt,
|
||||
"Short prompt two",
|
||||
"Short prompt three",
|
||||
"Short prompt four",
|
||||
],
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.suggested_prompts == [
|
||||
"Short prompt one",
|
||||
"Short prompt two",
|
||||
"Short prompt three",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
async def test_extract_business_understanding_filters_nulls():
|
||||
"""Null values from LLM should be excluded from the result."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
@@ -485,7 +437,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.business_name is None
|
||||
@@ -493,7 +445,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_invalid_json():
|
||||
async def test_extract_business_understanding_invalid_json():
|
||||
"""Invalid JSON from LLM should raise JSONDecodeError."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = "not valid json {"
|
||||
@@ -507,11 +459,11 @@ async def test_extract_business_understanding_from_tally_invalid_json():
|
||||
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
|
||||
pytest.raises(json.JSONDecodeError),
|
||||
):
|
||||
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_timeout():
|
||||
async def test_extract_business_understanding_timeout():
|
||||
"""LLM timeout should propagate as asyncio.TimeoutError."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
|
||||
@@ -521,7 +473,7 @@ async def test_extract_business_understanding_from_tally_timeout():
|
||||
patch("backend.data.tally._LLM_TIMEOUT", 0.001),
|
||||
pytest.raises(asyncio.TimeoutError),
|
||||
):
|
||||
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
|
||||
# ── _refresh_cache ───────────────────────────────────────────────────────────
|
||||
@@ -540,7 +492,7 @@ async def test_refresh_cache_full_fetch():
|
||||
submissions = SAMPLE_SUBMISSIONS
|
||||
|
||||
with (
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
@@ -588,7 +540,7 @@ async def test_refresh_cache_incremental_fetch():
|
||||
new_submissions = [SAMPLE_SUBMISSIONS[0]] # Just Alice
|
||||
|
||||
with (
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
|
||||
@@ -86,11 +86,6 @@ class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
None, description="Any additional context"
|
||||
)
|
||||
|
||||
# Suggested prompts (UI-only, not included in system prompt)
|
||||
suggested_prompts: Optional[list[str]] = pydantic.Field(
|
||||
None, description="LLM-generated suggested prompts based on business context"
|
||||
)
|
||||
|
||||
|
||||
class BusinessUnderstanding(pydantic.BaseModel):
|
||||
"""Full business understanding model returned from database."""
|
||||
@@ -127,9 +122,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
|
||||
# Suggested prompts (UI-only, not included in system prompt)
|
||||
suggested_prompts: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
|
||||
"""Convert database record to Pydantic model."""
|
||||
@@ -157,7 +149,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
current_software=_json_to_list(business.get("current_software")),
|
||||
existing_automation=_json_to_list(business.get("existing_automation")),
|
||||
additional_notes=business.get("additional_notes"),
|
||||
suggested_prompts=_json_to_list(data.get("suggested_prompts")),
|
||||
)
|
||||
|
||||
|
||||
@@ -175,62 +166,6 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
|
||||
return merged
|
||||
|
||||
|
||||
def merge_business_understanding_data(
|
||||
existing_data: dict[str, Any],
|
||||
input_data: BusinessUnderstandingInput,
|
||||
) -> dict[str, Any]:
|
||||
merged_data = dict(existing_data)
|
||||
|
||||
merged_business: dict[str, Any] = {}
|
||||
if isinstance(merged_data.get("business"), dict):
|
||||
merged_business = dict(merged_data["business"])
|
||||
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
if input_data.user_name is not None:
|
||||
merged_data["name"] = input_data.user_name
|
||||
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
merged_business[field] = value
|
||||
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(merged_business.get(field))
|
||||
merged_list = _merge_lists(existing_list, value)
|
||||
merged_business[field] = merged_list
|
||||
|
||||
merged_business["version"] = 1
|
||||
merged_data["business"] = merged_business
|
||||
|
||||
# suggested_prompts lives at the top level (not under `business`) because
|
||||
# it's a UI-only artifact consumed by the frontend, not business understanding
|
||||
# data. The `business` sub-dict feeds the system prompt.
|
||||
if input_data.suggested_prompts is not None:
|
||||
merged_data["suggested_prompts"] = input_data.suggested_prompts
|
||||
|
||||
return merged_data
|
||||
|
||||
|
||||
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||
"""Get business understanding from Redis cache."""
|
||||
try:
|
||||
@@ -310,18 +245,63 @@ async def upsert_business_understanding(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Get existing data structure or start fresh
|
||||
existing_data: dict[str, Any] = {}
|
||||
if existing and isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
merged_data = merge_business_understanding_data(existing_data, input_data)
|
||||
existing_business: dict[str, Any] = {}
|
||||
if isinstance(existing_data.get("business"), dict):
|
||||
existing_business = dict(existing_data["business"])
|
||||
|
||||
# Business fields (stored inside business object)
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
# Handle top-level name field
|
||||
if input_data.user_name is not None:
|
||||
existing_data["name"] = input_data.user_name
|
||||
|
||||
# Business string fields - overwrite if provided
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_business[field] = value
|
||||
|
||||
# Business list fields - merge with existing
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(existing_business.get(field))
|
||||
merged = _merge_lists(existing_list, value)
|
||||
existing_business[field] = merged
|
||||
|
||||
# Set version and nest business data
|
||||
existing_business["version"] = 1
|
||||
existing_data["business"] = existing_business
|
||||
|
||||
# Upsert with the merged data
|
||||
record = await CoPilotUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(merged_data)},
|
||||
"update": {"data": SafeJson(merged_data)},
|
||||
"create": {"userId": user_id, "data": SafeJson(existing_data)},
|
||||
"update": {"data": SafeJson(existing_data)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
"""Tests for business understanding merge and format logic."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
BusinessUnderstandingInput,
|
||||
format_understanding_for_prompt,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**kwargs: Any) -> BusinessUnderstandingInput:
|
||||
"""Create a BusinessUnderstandingInput with only the specified fields."""
|
||||
return BusinessUnderstandingInput.model_validate(kwargs)
|
||||
|
||||
|
||||
# ─── merge_business_understanding_data: suggested_prompts ─────────────
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_overwrites_existing():
|
||||
"""New suggested_prompts should fully replace existing ones (not append)."""
|
||||
existing = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": ["Old prompt 1", "Old prompt 2"],
|
||||
}
|
||||
input_data = _make_input(
|
||||
suggested_prompts=["New prompt A", "New prompt B", "New prompt C"],
|
||||
)
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == [
|
||||
"New prompt A",
|
||||
"New prompt B",
|
||||
"New prompt C",
|
||||
]
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_none_preserves_existing():
|
||||
"""When input has suggested_prompts=None, existing prompts are preserved."""
|
||||
existing = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": ["Keep me"],
|
||||
}
|
||||
input_data = _make_input(industry="Finance")
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == ["Keep me"]
|
||||
assert result["business"]["industry"] == "Finance"
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_added_to_empty_data():
|
||||
"""Suggested prompts are set at top level even when starting from empty data."""
|
||||
existing: dict[str, Any] = {}
|
||||
input_data = _make_input(suggested_prompts=["Prompt 1"])
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == ["Prompt 1"]
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_empty_list_overwrites():
|
||||
"""An explicit empty list should overwrite existing prompts."""
|
||||
existing: dict[str, Any] = {
|
||||
"suggested_prompts": ["Old prompt"],
|
||||
"business": {"version": 1},
|
||||
}
|
||||
input_data = _make_input(suggested_prompts=[])
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == []
|
||||
|
||||
|
||||
# ─── format_understanding_for_prompt: excludes suggested_prompts ──────
|
||||
|
||||
|
||||
def test_format_understanding_excludes_suggested_prompts():
|
||||
"""suggested_prompts is UI-only and must NOT appear in the system prompt."""
|
||||
understanding = BusinessUnderstanding(
|
||||
id="test-id",
|
||||
user_id="user-1",
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
updated_at=datetime.now(tz=timezone.utc),
|
||||
user_name="Alice",
|
||||
industry="Technology",
|
||||
suggested_prompts=["Automate reports", "Set up alerts", "Track KPIs"],
|
||||
)
|
||||
|
||||
formatted = format_understanding_for_prompt(understanding)
|
||||
|
||||
assert "Alice" in formatted
|
||||
assert "Technology" in formatted
|
||||
assert "suggested_prompts" not in formatted
|
||||
assert "Automate reports" not in formatted
|
||||
assert "Set up alerts" not in formatted
|
||||
assert "Track KPIs" not in formatted
|
||||
@@ -61,12 +61,7 @@ from backend.util.decorator import (
|
||||
error_logged,
|
||||
time_measured,
|
||||
)
|
||||
from backend.util.exceptions import (
|
||||
GraphNotFoundError,
|
||||
InsufficientBalanceError,
|
||||
ModerationError,
|
||||
NotFoundError,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.metrics import DiscordChannel
|
||||
@@ -380,16 +375,9 @@ async def execute_node(
|
||||
log_metadata.debug("Node produced output", **{output_name: output_data})
|
||||
yield output_name, output_data
|
||||
except Exception as ex:
|
||||
# Only capture unexpected errors to Sentry, not user-caused ones.
|
||||
# Most ValueError subclasses here are expected (BlockExecutionError,
|
||||
# InsufficientBalanceError, plain ValueError for auth/disabled blocks, etc.)
|
||||
# but NotFoundError/GraphNotFoundError could indicate real platform issues.
|
||||
is_expected = isinstance(ex, ValueError) and not isinstance(
|
||||
ex, (NotFoundError, GraphNotFoundError)
|
||||
)
|
||||
if not is_expected:
|
||||
sentry_sdk.capture_exception(error=ex, scope=scope)
|
||||
sentry_sdk.flush()
|
||||
# Capture exception WITH context still set before restoring scope
|
||||
sentry_sdk.capture_exception(error=ex, scope=scope)
|
||||
sentry_sdk.flush() # Ensure it's sent before we restore scope
|
||||
# Re-raise to maintain normal error flow
|
||||
raise
|
||||
finally:
|
||||
@@ -1490,7 +1478,7 @@ class ExecutionProcessor:
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send low balance Discord alert: {e}")
|
||||
logger.error(f"Failed to send low balance Discord alert: {e}")
|
||||
|
||||
|
||||
class ExecutionManager(AppProcess):
|
||||
@@ -1912,16 +1900,17 @@ class ExecutionManager(AppProcess):
|
||||
channel = client.get_channel()
|
||||
channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())
|
||||
|
||||
thread.join(timeout=300)
|
||||
if thread.is_alive():
|
||||
logger.warning(
|
||||
try:
|
||||
thread.join(timeout=300)
|
||||
except TimeoutError:
|
||||
logger.error(
|
||||
f"{prefix} ⚠️ Run thread did not finish in time, forcing disconnect"
|
||||
)
|
||||
|
||||
client.disconnect()
|
||||
logger.info(f"{prefix} ✅ Run client disconnected")
|
||||
except Exception as e:
|
||||
logger.warning(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")
|
||||
logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")
|
||||
|
||||
def cleanup(self):
|
||||
"""Override cleanup to implement graceful shutdown with active execution waiting."""
|
||||
@@ -1937,9 +1926,7 @@ class ExecutionManager(AppProcess):
|
||||
)
|
||||
logger.info(f"{prefix} ✅ Exec consumer has been signaled to stop")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"{prefix} ⚠️ Error signaling consumer to stop: {type(e)} {e}"
|
||||
)
|
||||
logger.error(f"{prefix} ⚠️ Error signaling consumer to stop: {type(e)} {e}")
|
||||
|
||||
# Wait for active executions to complete
|
||||
if self.active_graph_runs:
|
||||
@@ -1970,7 +1957,7 @@ class ExecutionManager(AppProcess):
|
||||
waited += wait_interval
|
||||
|
||||
if self.active_graph_runs:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"{prefix} ⚠️ {len(self.active_graph_runs)} executions still running after {max_wait}s"
|
||||
)
|
||||
else:
|
||||
@@ -1981,7 +1968,7 @@ class ExecutionManager(AppProcess):
|
||||
self.executor.shutdown(cancel_futures=True, wait=False)
|
||||
logger.info(f"{prefix} ✅ Executor shutdown completed")
|
||||
except Exception as e:
|
||||
logger.warning(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
|
||||
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
|
||||
|
||||
# Release remaining execution locks
|
||||
try:
|
||||
|
||||
@@ -94,7 +94,7 @@ SCHEDULER_OPERATION_TIMEOUT_SECONDS = 300 # 5 minutes for scheduler operations
|
||||
def job_listener(event):
|
||||
"""Logs job execution outcomes for better monitoring."""
|
||||
if event.exception:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Job {event.job_id} failed: {type(event.exception).__name__}: {event.exception}"
|
||||
)
|
||||
else:
|
||||
@@ -137,7 +137,7 @@ def run_async(coro, timeout: float = SCHEDULER_OPERATION_TIMEOUT_SECONDS):
|
||||
try:
|
||||
return future.result(timeout=timeout)
|
||||
except Exception as e:
|
||||
logger.warning(f"Async operation failed: {type(e).__name__}: {e}")
|
||||
logger.error(f"Async operation failed: {type(e).__name__}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ async def _execute_graph(**kwargs):
|
||||
|
||||
|
||||
async def _handle_graph_validation_error(args: "GraphExecutionJobArgs") -> None:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Scheduled Graph {args.graph_id} failed validation. Unscheduling graph"
|
||||
)
|
||||
if args.schedule_id:
|
||||
@@ -196,9 +196,8 @@ async def _handle_graph_validation_error(args: "GraphExecutionJobArgs") -> None:
|
||||
user_id=args.user_id,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unable to unschedule graph: {args.graph_id} as this is an old job "
|
||||
f"with no associated schedule_id please remove manually"
|
||||
logger.error(
|
||||
f"Unable to unschedule graph: {args.graph_id} as this is an old job with no associated schedule_id please remove manually"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from backend.util.exceptions import (
|
||||
)
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import coerce_inputs_to_schema
|
||||
from backend.util.type import convert
|
||||
|
||||
config = Config()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
||||
@@ -213,8 +213,11 @@ def validate_exec(
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Coerce non-matching data types to the expected input schema.
|
||||
coerce_inputs_to_schema(data, schema)
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
value = data.get(name)
|
||||
if (value is not None) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
|
||||
@@ -303,9 +303,9 @@ class NotificationManager(AppService):
|
||||
)
|
||||
|
||||
if not oldest_message:
|
||||
logger.warning(
|
||||
f"Batch for user {batch.user_id} and type {notification_type} "
|
||||
f"has no oldest message — batch may have been cleared concurrently"
|
||||
# this should never happen
|
||||
logger.error(
|
||||
f"Batch for user {batch.user_id} and type {notification_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -318,7 +318,7 @@ class NotificationManager(AppService):
|
||||
).get_user_email_by_id(batch.user_id)
|
||||
|
||||
if not recipient_email:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"User email not found for user {batch.user_id}"
|
||||
)
|
||||
continue
|
||||
@@ -344,7 +344,7 @@ class NotificationManager(AppService):
|
||||
).get_user_notification_batch(batch.user_id, notification_type)
|
||||
|
||||
if not batch_data or not batch_data.notifications:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Batch data not found for user {batch.user_id}"
|
||||
)
|
||||
# Clear the batch
|
||||
@@ -372,7 +372,7 @@ class NotificationManager(AppService):
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Error parsing notification event: {e=}, {db_event=}"
|
||||
)
|
||||
continue
|
||||
@@ -415,10 +415,7 @@ class NotificationManager(AppService):
|
||||
async def discord_system_alert(
|
||||
self, content: str, channel: DiscordChannel = DiscordChannel.PLATFORM
|
||||
):
|
||||
try:
|
||||
await discord_send_alert(content, channel)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send Discord system alert: {e}")
|
||||
await discord_send_alert(content, channel)
|
||||
|
||||
async def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
|
||||
"""Queue a scheduled notification - exposed method for other services to call"""
|
||||
@@ -519,7 +516,7 @@ class NotificationManager(AppService):
|
||||
raise ValueError("Invalid event type or params")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to gather summary data: {e}")
|
||||
logger.error(f"Failed to gather summary data: {e}")
|
||||
# Return sensible defaults in case of error
|
||||
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
|
||||
params, DailySummaryParams
|
||||
@@ -565,9 +562,8 @@ class NotificationManager(AppService):
|
||||
should_retry=False
|
||||
).get_user_notification_oldest_message_in_batch(user_id, event_type)
|
||||
if not oldest_message:
|
||||
logger.warning(
|
||||
f"Batch for user {user_id} and type {event_type} "
|
||||
f"has no oldest message — batch may have been cleared concurrently"
|
||||
logger.error(
|
||||
f"Batch for user {user_id} and type {event_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
|
||||
)
|
||||
return False
|
||||
oldest_age = oldest_message.created_at
|
||||
@@ -589,7 +585,7 @@ class NotificationManager(AppService):
|
||||
get_notif_data_type(event.type)
|
||||
].model_validate_json(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing message due to non matching schema {e}")
|
||||
logger.error(f"Error parsing message due to non matching schema {e}")
|
||||
return None
|
||||
|
||||
async def _process_admin_message(self, message: str) -> bool:
|
||||
@@ -618,7 +614,7 @@ class NotificationManager(AppService):
|
||||
should_retry=False
|
||||
).get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
logger.warning(f"User email not found for user {event.user_id}")
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
|
||||
should_send = await self._should_email_user_based_on_preference(
|
||||
@@ -655,7 +651,7 @@ class NotificationManager(AppService):
|
||||
should_retry=False
|
||||
).get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
logger.warning(f"User email not found for user {event.user_id}")
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
|
||||
should_send = await self._should_email_user_based_on_preference(
|
||||
@@ -676,7 +672,7 @@ class NotificationManager(AppService):
|
||||
should_retry=False
|
||||
).get_user_notification_batch(event.user_id, event.type)
|
||||
if not batch or not batch.notifications:
|
||||
logger.warning(f"Batch not found for user {event.user_id}")
|
||||
logger.error(f"Batch not found for user {event.user_id}")
|
||||
return False
|
||||
unsub_link = generate_unsubscribe_link(event.user_id)
|
||||
|
||||
@@ -749,7 +745,7 @@ class NotificationManager(AppService):
|
||||
f"Removed {len(chunk_ids)} sent notifications from batch"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Failed to remove sent notifications: {e}"
|
||||
)
|
||||
# Continue anyway - better to risk duplicates than lose emails
|
||||
@@ -774,7 +770,7 @@ class NotificationManager(AppService):
|
||||
else:
|
||||
# Message is too large even after size reduction
|
||||
if attempt_size == 1:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Failed to send notification at index {i}: "
|
||||
f"Single notification exceeds email size limit "
|
||||
f"({len(test_message):,} chars > {MAX_EMAIL_SIZE:,} chars). "
|
||||
@@ -793,7 +789,7 @@ class NotificationManager(AppService):
|
||||
f"Removed oversized notification {chunk_ids[0]} from batch permanently"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Failed to remove oversized notification: {e}"
|
||||
)
|
||||
|
||||
@@ -827,7 +823,7 @@ class NotificationManager(AppService):
|
||||
f"Set email verification to false for user {event.user_id}"
|
||||
)
|
||||
except Exception as deactivation_error:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Failed to deactivate email for user {event.user_id}: "
|
||||
f"{deactivation_error}"
|
||||
)
|
||||
@@ -839,7 +835,7 @@ class NotificationManager(AppService):
|
||||
f"Disabled all notification preferences for user {event.user_id}"
|
||||
)
|
||||
except Exception as disable_error:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Failed to disable notification preferences: {disable_error}"
|
||||
)
|
||||
|
||||
@@ -852,7 +848,7 @@ class NotificationManager(AppService):
|
||||
f"Cleared ALL notification batches for user {event.user_id}"
|
||||
)
|
||||
except Exception as remove_error:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Failed to clear batches for inactive recipient: {remove_error}"
|
||||
)
|
||||
|
||||
@@ -863,7 +859,7 @@ class NotificationManager(AppService):
|
||||
"422" in error_message
|
||||
or "unprocessable" in error_message
|
||||
):
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Failed to send notification at index {i}: "
|
||||
f"Malformed notification data rejected by Postmark. "
|
||||
f"Error: {e}. Removing from batch permanently."
|
||||
@@ -881,7 +877,7 @@ class NotificationManager(AppService):
|
||||
"Removed malformed notification from batch permanently"
|
||||
)
|
||||
except Exception as remove_error:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Failed to remove malformed notification: {remove_error}"
|
||||
)
|
||||
# Check if it's a ValueError for size limit
|
||||
@@ -889,14 +885,14 @@ class NotificationManager(AppService):
|
||||
isinstance(e, ValueError)
|
||||
and "too large" in error_message
|
||||
):
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Failed to send notification at index {i}: "
|
||||
f"Notification size exceeds email limit. "
|
||||
f"Error: {e}. Skipping this notification."
|
||||
)
|
||||
# Other API errors
|
||||
else:
|
||||
logger.warning(
|
||||
logger.error(
|
||||
f"Failed to send notification at index {i}: "
|
||||
f"Email API error ({error_type}): {e}. "
|
||||
f"Skipping this notification."
|
||||
@@ -911,9 +907,7 @@ class NotificationManager(AppService):
|
||||
|
||||
if not chunk_sent:
|
||||
# Should not reach here due to single notification handling
|
||||
logger.warning(
|
||||
f"Failed to send notifications starting at index {i}"
|
||||
)
|
||||
logger.error(f"Failed to send notifications starting at index {i}")
|
||||
failed_indices.append(i)
|
||||
i += 1
|
||||
|
||||
@@ -952,7 +946,7 @@ class NotificationManager(AppService):
|
||||
should_retry=False
|
||||
).get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
logger.warning(f"User email not found for user {event.user_id}")
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
should_send = await self._should_email_user_based_on_preference(
|
||||
event.user_id, event.type
|
||||
@@ -1013,10 +1007,7 @@ class NotificationManager(AppService):
|
||||
# Let message.process() handle the rejection
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error processing message in {queue_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.error(f"Error processing message in {queue_name}: {e}")
|
||||
# Let message.process() handle the rejection
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -256,9 +256,9 @@ class TestNotificationErrorHandling:
|
||||
assert 2 not in successful_indices # Index 2 failed
|
||||
|
||||
# Verify 422 error was logged
|
||||
warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
|
||||
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
|
||||
assert any(
|
||||
"422" in call or "malformed" in call.lower() for call in warning_calls
|
||||
"422" in call or "malformed" in call.lower() for call in error_calls
|
||||
)
|
||||
|
||||
# Verify all notifications were removed (4 successful + 1 malformed)
|
||||
@@ -371,10 +371,10 @@ class TestNotificationErrorHandling:
|
||||
assert 3 not in successful_indices # Index 3 was not sent
|
||||
|
||||
# Verify oversized error was logged
|
||||
warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
|
||||
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
|
||||
assert any(
|
||||
"exceeds email size limit" in call or "oversized" in call.lower()
|
||||
for call in warning_calls
|
||||
for call in error_calls
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -478,10 +478,10 @@ class TestNotificationErrorHandling:
|
||||
assert 1 in failed_indices # Index 1 failed
|
||||
|
||||
# Verify generic error was logged
|
||||
warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
|
||||
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
|
||||
assert any(
|
||||
"api error" in call.lower() or "skipping" in call.lower()
|
||||
for call in warning_calls
|
||||
for call in error_calls
|
||||
)
|
||||
|
||||
# Only successful ones should be removed from batch (failed one stays for retry)
|
||||
|
||||
@@ -613,5 +613,5 @@ async def cleanup_expired_files_async() -> int:
|
||||
)
|
||||
return deleted_count
|
||||
except Exception as e:
|
||||
logger.warning(f"[CloudStorage] Error during cloud storage cleanup: {e}")
|
||||
logger.error(f"[CloudStorage] Error during cloud storage cleanup: {e}")
|
||||
return 0
|
||||
|
||||
@@ -275,12 +275,13 @@ async def store_media_file(
|
||||
# Process file
|
||||
elif file.startswith("data:"):
|
||||
# Data URI
|
||||
parsed_uri = parse_data_uri(file)
|
||||
if parsed_uri is None:
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", file, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
"Invalid data URI format. Expected data:<mime>;base64,<data>"
|
||||
)
|
||||
mime_type, b64_content = parsed_uri
|
||||
mime_type = match.group(1).strip().lower()
|
||||
b64_content = match.group(2).strip()
|
||||
|
||||
# Generate filename and decode
|
||||
extension = _extension_from_mime(mime_type)
|
||||
@@ -414,70 +415,13 @@ def get_dir_size(path: Path) -> int:
|
||||
return total
|
||||
|
||||
|
||||
async def resolve_media_content(
|
||||
content: MediaFileType,
|
||||
execution_context: "ExecutionContext",
|
||||
*,
|
||||
return_format: MediaReturnFormat,
|
||||
) -> MediaFileType:
|
||||
"""Resolve a ``MediaFileType`` value if it is a media reference, pass through otherwise.
|
||||
|
||||
Convenience wrapper around :func:`is_media_file_ref` + :func:`store_media_file`.
|
||||
Plain text content (source code, filenames) is returned unchanged. Media
|
||||
references (``data:``, ``workspace://``, ``http(s)://``) are resolved via
|
||||
:func:`store_media_file` using *return_format*.
|
||||
|
||||
Use this when a block field is typed as ``MediaFileType`` but may contain
|
||||
either literal text or a media reference.
|
||||
"""
|
||||
if not content or not is_media_file_ref(content):
|
||||
return content
|
||||
return await store_media_file(
|
||||
content, execution_context, return_format=return_format
|
||||
)
|
||||
|
||||
|
||||
def is_media_file_ref(value: str) -> bool:
|
||||
"""Return True if *value* looks like a ``MediaFileType`` reference.
|
||||
|
||||
Detects data URIs, workspace:// references, and HTTP(S) URLs — the
|
||||
formats accepted by :func:`store_media_file`. Plain text content
|
||||
(e.g. source code, filenames) returns False.
|
||||
|
||||
Known limitation: HTTP(S) URL detection is heuristic. Any string that
|
||||
starts with ``http://`` or ``https://`` is treated as a media URL, even
|
||||
if it appears as a URL inside source-code comments or documentation.
|
||||
Blocks that produce source code or Markdown as output may therefore
|
||||
trigger false positives. Callers that need higher precision should
|
||||
inspect the string further (e.g. verify the URL is reachable or has a
|
||||
media-friendly extension).
|
||||
|
||||
Note: this does *not* match local file paths, which are ambiguous
|
||||
(could be filenames or actual paths). Blocks that need to resolve
|
||||
local paths should check for them separately.
|
||||
"""
|
||||
return value.startswith(("data:", "workspace://", "http://", "https://"))
|
||||
|
||||
|
||||
def parse_data_uri(value: str) -> tuple[str, str] | None:
|
||||
"""Parse a ``data:<mime>;base64,<payload>`` URI.
|
||||
|
||||
Returns ``(mime_type, base64_payload)`` if *value* is a valid data URI,
|
||||
or ``None`` if it is not.
|
||||
"""
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", value, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(1).strip().lower(), match.group(2).strip()
|
||||
|
||||
|
||||
def get_mime_type(file: str) -> str:
|
||||
"""
|
||||
Get the MIME type of a file, whether it's a data URI, URL, or local path.
|
||||
"""
|
||||
if file.startswith("data:"):
|
||||
parsed_uri = parse_data_uri(file)
|
||||
return parsed_uri[0] if parsed_uri else "application/octet-stream"
|
||||
match = re.match(r"^data:([^;]+);base64,", file)
|
||||
return match.group(1) if match else "application/octet-stream"
|
||||
|
||||
elif file.startswith(("http://", "https://")):
|
||||
parsed_url = urlparse(file)
|
||||
|
||||
@@ -1,375 +0,0 @@
|
||||
"""Parse file content into structured Python objects based on file format.
|
||||
|
||||
Used by the ``@@agptfile:`` expansion system to eagerly parse well-known file
|
||||
formats into native Python types *before* schema-driven coercion runs. This
|
||||
lets blocks with ``Any``-typed inputs receive structured data rather than raw
|
||||
strings, while blocks expecting strings get the value coerced back via
|
||||
``convert()``.
|
||||
|
||||
Supported formats:
|
||||
|
||||
- **JSON** (``.json``) — arrays and objects are promoted; scalars stay as strings
|
||||
- **JSON Lines** (``.jsonl``, ``.ndjson``) — each non-empty line parsed as JSON;
|
||||
when all lines are dicts with the same keys (tabular data), output is
|
||||
``list[list[Any]]`` with a header row, consistent with CSV/Parquet/Excel;
|
||||
otherwise returns a plain ``list`` of parsed values
|
||||
- **CSV** (``.csv``) — ``csv.reader`` → ``list[list[str]]``
|
||||
- **TSV** (``.tsv``) — tab-delimited → ``list[list[str]]``
|
||||
- **YAML** (``.yaml``, ``.yml``) — parsed via PyYAML; containers only
|
||||
- **TOML** (``.toml``) — parsed via stdlib ``tomllib``
|
||||
- **Parquet** (``.parquet``) — via pandas/pyarrow → ``list[list[Any]]`` with header row
|
||||
- **Excel** (``.xlsx``) — via pandas/openpyxl → ``list[list[Any]]`` with header row
|
||||
(legacy ``.xls`` is **not** supported — only the modern OOXML format)
|
||||
|
||||
The **fallback contract** is enforced by :func:`parse_file_content`, not by
|
||||
individual parser functions. If any parser raises, ``parse_file_content``
|
||||
catches the exception and returns the original content unchanged (string for
|
||||
text formats, bytes for binary formats). Callers should never see an
|
||||
exception from the public API when ``strict=False``.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import tomllib
|
||||
import zipfile
|
||||
from collections.abc import Callable
|
||||
|
||||
# posixpath.splitext handles forward-slash URI paths correctly on all platforms,
|
||||
# unlike os.path.splitext which uses platform-native separators.
|
||||
from posixpath import splitext
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extension / MIME → format label mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EXT_TO_FORMAT: dict[str, str] = {
|
||||
".json": "json",
|
||||
".jsonl": "jsonl",
|
||||
".ndjson": "jsonl",
|
||||
".csv": "csv",
|
||||
".tsv": "tsv",
|
||||
".yaml": "yaml",
|
||||
".yml": "yaml",
|
||||
".toml": "toml",
|
||||
".parquet": "parquet",
|
||||
".xlsx": "xlsx",
|
||||
}
|
||||
|
||||
MIME_TO_FORMAT: dict[str, str] = {
|
||||
"application/json": "json",
|
||||
"application/x-ndjson": "jsonl",
|
||||
"application/jsonl": "jsonl",
|
||||
"text/csv": "csv",
|
||||
"text/tab-separated-values": "tsv",
|
||||
"application/x-yaml": "yaml",
|
||||
"application/yaml": "yaml",
|
||||
"text/yaml": "yaml",
|
||||
"application/toml": "toml",
|
||||
"application/vnd.apache.parquet": "parquet",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
|
||||
}
|
||||
|
||||
# Formats that require raw bytes rather than decoded text.
|
||||
BINARY_FORMATS: frozenset[str] = frozenset({"parquet", "xlsx"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def infer_format_from_uri(uri: str) -> str | None:
|
||||
"""Return a format label based on URI extension or MIME fragment.
|
||||
|
||||
Returns ``None`` when the format cannot be determined — the caller should
|
||||
fall back to returning the content as a plain string.
|
||||
"""
|
||||
# 1. Check MIME fragment (workspace://abc123#application/json)
|
||||
if "#" in uri:
|
||||
_, fragment = uri.rsplit("#", 1)
|
||||
fmt = MIME_TO_FORMAT.get(fragment.lower())
|
||||
if fmt:
|
||||
return fmt
|
||||
|
||||
# 2. Check file extension from the path portion.
|
||||
# Strip the fragment first so ".json#mime" doesn't confuse splitext.
|
||||
path = uri.split("#")[0].split("?")[0]
|
||||
_, ext = splitext(path)
|
||||
fmt = _EXT_TO_FORMAT.get(ext.lower())
|
||||
if fmt is not None:
|
||||
return fmt
|
||||
|
||||
# Legacy .xls is not supported — map it so callers can produce a
|
||||
# user-friendly error instead of returning garbled binary.
|
||||
if ext.lower() == ".xls":
|
||||
return "xls"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_file_content(content: str | bytes, fmt: str, *, strict: bool = False) -> Any:
|
||||
"""Parse *content* according to *fmt* and return a native Python value.
|
||||
|
||||
When *strict* is ``False`` (default), returns the original *content*
|
||||
unchanged if *fmt* is not recognised or parsing fails for any reason.
|
||||
This mode **never raises**.
|
||||
|
||||
When *strict* is ``True``, parsing errors are propagated to the caller.
|
||||
Unrecognised formats or type mismatches (e.g. text for a binary format)
|
||||
still return *content* unchanged without raising.
|
||||
"""
|
||||
if fmt == "xls":
|
||||
return (
|
||||
"[Unsupported format] Legacy .xls files are not supported. "
|
||||
"Please re-save the file as .xlsx (Excel 2007+) and upload again."
|
||||
)
|
||||
|
||||
try:
|
||||
if fmt in BINARY_FORMATS:
|
||||
parser = _BINARY_PARSERS.get(fmt)
|
||||
if parser is None:
|
||||
return content
|
||||
if isinstance(content, str):
|
||||
# Caller gave us text for a binary format — can't parse.
|
||||
return content
|
||||
return parser(content)
|
||||
|
||||
parser = _TEXT_PARSERS.get(fmt)
|
||||
if parser is None:
|
||||
return content
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
return parser(content)
|
||||
|
||||
except PARSE_EXCEPTIONS:
|
||||
if strict:
|
||||
raise
|
||||
logger.debug("Structured parsing failed for format=%s, falling back", fmt)
|
||||
return content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception loading helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_openpyxl_exception() -> type[Exception]:
|
||||
"""Return openpyxl's InvalidFileException, raising ImportError if absent."""
|
||||
from openpyxl.utils.exceptions import InvalidFileException # noqa: PLC0415
|
||||
|
||||
return InvalidFileException
|
||||
|
||||
|
||||
def _load_arrow_exception() -> type[Exception]:
|
||||
"""Return pyarrow's ArrowException, raising ImportError if absent."""
|
||||
from pyarrow import ArrowException # noqa: PLC0415
|
||||
|
||||
return ArrowException
|
||||
|
||||
|
||||
def _optional_exc(loader: "Callable[[], type[Exception]]") -> "type[Exception] | None":
|
||||
"""Return the exception class from *loader*, or ``None`` if the dep is absent."""
|
||||
try:
|
||||
return loader()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
# Exception types that can be raised during file content parsing.
|
||||
# Shared between ``parse_file_content`` (which catches them in non-strict mode)
|
||||
# and ``file_ref._expand_bare_ref`` (which re-raises them as FileRefExpansionError).
|
||||
#
|
||||
# Optional-dependency exception types are loaded via a helper that raises
|
||||
# ``ImportError`` at *parse time* rather than silently becoming ``None`` here.
|
||||
# This ensures mypy sees clean types and missing deps surface as real errors.
|
||||
PARSE_EXCEPTIONS: tuple[type[BaseException], ...] = tuple(
|
||||
exc
|
||||
for exc in (
|
||||
json.JSONDecodeError,
|
||||
csv.Error,
|
||||
yaml.YAMLError,
|
||||
tomllib.TOMLDecodeError,
|
||||
ValueError,
|
||||
UnicodeDecodeError,
|
||||
ImportError,
|
||||
OSError,
|
||||
KeyError,
|
||||
TypeError,
|
||||
zipfile.BadZipFile,
|
||||
_optional_exc(_load_openpyxl_exception),
|
||||
# ArrowException covers ArrowIOError and ArrowCapacityError which
|
||||
# do not inherit from standard exceptions; ArrowInvalid/ArrowTypeError
|
||||
# already map to ValueError/TypeError but this catches the rest.
|
||||
_optional_exc(_load_arrow_exception),
|
||||
)
|
||||
if exc is not None
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text-based parsers (content: str → Any)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_container(parser: Callable[[str], Any], content: str) -> list | dict | str:
|
||||
"""Parse *content* and return the result only if it is a container (list/dict).
|
||||
|
||||
Scalar values (strings, numbers, booleans, None) are discarded and the
|
||||
original *content* string is returned instead. This prevents e.g. a JSON
|
||||
file containing just ``"42"`` from silently becoming an int.
|
||||
"""
|
||||
parsed = parser(content)
|
||||
if isinstance(parsed, (list, dict)):
|
||||
return parsed
|
||||
return content
|
||||
|
||||
|
||||
def _parse_json(content: str) -> list | dict | str:
|
||||
return _parse_container(json.loads, content)
|
||||
|
||||
|
||||
def _parse_jsonl(content: str) -> Any:
|
||||
lines = [json.loads(line) for line in content.splitlines() if line.strip()]
|
||||
if not lines:
|
||||
return content
|
||||
|
||||
# When every line is a dict with the same keys, convert to table format
|
||||
# (header row + data rows) — consistent with CSV/TSV/Parquet/Excel output.
|
||||
# Require ≥2 dicts so a single-line JSONL stays as [dict] (not a table).
|
||||
if len(lines) >= 2 and all(isinstance(obj, dict) for obj in lines):
|
||||
keys = list(lines[0].keys())
|
||||
# Cache as tuple to avoid O(n×k) list allocations in the all() call.
|
||||
keys_tuple = tuple(keys)
|
||||
if keys and all(tuple(obj.keys()) == keys_tuple for obj in lines[1:]):
|
||||
return [keys] + [[obj[k] for k in keys] for obj in lines]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _parse_csv(content: str) -> Any:
|
||||
return _parse_delimited(content, delimiter=",")
|
||||
|
||||
|
||||
def _parse_tsv(content: str) -> Any:
|
||||
return _parse_delimited(content, delimiter="\t")
|
||||
|
||||
|
||||
def _parse_delimited(content: str, *, delimiter: str) -> Any:
|
||||
reader = csv.reader(io.StringIO(content), delimiter=delimiter)
|
||||
# csv.reader never yields [] — blank lines yield [""]. Filter out
|
||||
# rows where every cell is empty (i.e. truly blank lines).
|
||||
rows = [row for row in reader if _row_has_content(row)]
|
||||
if not rows:
|
||||
return content
|
||||
# If the declared delimiter produces only single-column rows, try
|
||||
# sniffing the actual delimiter — catches misidentified files (e.g.
|
||||
# a tab-delimited file with a .csv extension).
|
||||
if len(rows[0]) == 1:
|
||||
try:
|
||||
dialect = csv.Sniffer().sniff(content[:8192])
|
||||
if dialect.delimiter != delimiter:
|
||||
reader = csv.reader(io.StringIO(content), dialect)
|
||||
rows = [row for row in reader if _row_has_content(row)]
|
||||
except csv.Error:
|
||||
pass
|
||||
if rows and len(rows[0]) >= 2:
|
||||
return rows
|
||||
return content
|
||||
|
||||
|
||||
def _row_has_content(row: list[str]) -> bool:
|
||||
"""Return True when *row* contains at least one non-empty cell.
|
||||
|
||||
``csv.reader`` never yields ``[]`` — truly blank lines yield ``[""]``.
|
||||
This predicate filters those out consistently across the initial read
|
||||
and the sniffer-fallback re-read.
|
||||
"""
|
||||
return any(cell for cell in row)
|
||||
|
||||
|
||||
def _parse_yaml(content: str) -> list | dict | str:
|
||||
# NOTE: YAML anchor/alias expansion can amplify input beyond the 10MB cap.
|
||||
# safe_load prevents code execution; for production hardening consider
|
||||
# a YAML parser with expansion limits (e.g. ruamel.yaml with max_alias_count).
|
||||
if "\n---" in content or content.startswith("---\n"):
|
||||
# Multi-document YAML: only the first document is parsed; the rest
|
||||
# are silently ignored by yaml.safe_load. Warn so callers are aware.
|
||||
logger.warning(
|
||||
"Multi-document YAML detected (--- separator); "
|
||||
"only the first document will be parsed."
|
||||
)
|
||||
return _parse_container(yaml.safe_load, content)
|
||||
|
||||
|
||||
def _parse_toml(content: str) -> Any:
|
||||
parsed = tomllib.loads(content)
|
||||
# tomllib.loads always returns a dict — return it even if empty.
|
||||
return parsed
|
||||
|
||||
|
||||
_TEXT_PARSERS: dict[str, Callable[[str], Any]] = {
|
||||
"json": _parse_json,
|
||||
"jsonl": _parse_jsonl,
|
||||
"csv": _parse_csv,
|
||||
"tsv": _parse_tsv,
|
||||
"yaml": _parse_yaml,
|
||||
"toml": _parse_toml,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary-based parsers (content: bytes → Any)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_parquet(content: bytes) -> list[list[Any]]:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.read_parquet(io.BytesIO(content))
|
||||
return _df_to_rows(df)
|
||||
|
||||
|
||||
def _parse_xlsx(content: bytes) -> list[list[Any]]:
|
||||
import pandas as pd
|
||||
|
||||
# Explicitly specify openpyxl engine; the default engine varies by pandas
|
||||
# version and does not support legacy .xls (which is excluded by our format map).
|
||||
df = pd.read_excel(io.BytesIO(content), engine="openpyxl")
|
||||
return _df_to_rows(df)
|
||||
|
||||
|
||||
def _df_to_rows(df: Any) -> list[list[Any]]:
|
||||
"""Convert a DataFrame to ``list[list[Any]]`` with a header row.
|
||||
|
||||
NaN values are replaced with ``None`` so the result is JSON-serializable.
|
||||
Uses explicit cell-level checking because ``df.where(df.notna(), None)``
|
||||
silently converts ``None`` back to ``NaN`` in float64 columns.
|
||||
"""
|
||||
header = df.columns.tolist()
|
||||
rows = [
|
||||
[None if _is_nan(cell) else cell for cell in row] for row in df.values.tolist()
|
||||
]
|
||||
return [header] + rows
|
||||
|
||||
|
||||
def _is_nan(cell: Any) -> bool:
|
||||
"""Check if a cell value is NaN, handling non-scalar types (lists, dicts).
|
||||
|
||||
``pd.isna()`` on a list/dict returns a boolean array which raises
|
||||
``ValueError`` in a boolean context. Guard with a scalar check first.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
return bool(pd.api.types.is_scalar(cell) and pd.isna(cell))
|
||||
|
||||
|
||||
_BINARY_PARSERS: dict[str, Callable[[bytes], Any]] = {
|
||||
"parquet": _parse_parquet,
|
||||
"xlsx": _parse_xlsx,
|
||||
}
|
||||
@@ -1,624 +0,0 @@
|
||||
"""Tests for file_content_parser — format inference and structured parsing."""
|
||||
|
||||
import io
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# infer_format_from_uri
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferFormat:
|
||||
# --- extension-based ---
|
||||
|
||||
def test_json_extension(self):
|
||||
assert infer_format_from_uri("/home/user/data.json") == "json"
|
||||
|
||||
def test_jsonl_extension(self):
|
||||
assert infer_format_from_uri("/tmp/events.jsonl") == "jsonl"
|
||||
|
||||
def test_ndjson_extension(self):
|
||||
assert infer_format_from_uri("/tmp/events.ndjson") == "jsonl"
|
||||
|
||||
def test_csv_extension(self):
|
||||
assert infer_format_from_uri("workspace:///reports/sales.csv") == "csv"
|
||||
|
||||
def test_tsv_extension(self):
|
||||
assert infer_format_from_uri("/home/user/data.tsv") == "tsv"
|
||||
|
||||
def test_yaml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.yaml") == "yaml"
|
||||
|
||||
def test_yml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.yml") == "yaml"
|
||||
|
||||
def test_toml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.toml") == "toml"
|
||||
|
||||
def test_parquet_extension(self):
|
||||
assert infer_format_from_uri("/data/table.parquet") == "parquet"
|
||||
|
||||
def test_xlsx_extension(self):
|
||||
assert infer_format_from_uri("/data/spreadsheet.xlsx") == "xlsx"
|
||||
|
||||
def test_xls_extension_returns_xls_label(self):
|
||||
# Legacy .xls is mapped so callers can produce a helpful error.
|
||||
assert infer_format_from_uri("/data/old_spreadsheet.xls") == "xls"
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert infer_format_from_uri("/data/FILE.JSON") == "json"
|
||||
assert infer_format_from_uri("/data/FILE.CSV") == "csv"
|
||||
|
||||
def test_unicode_filename(self):
|
||||
assert infer_format_from_uri("/home/user/\u30c7\u30fc\u30bf.json") == "json"
|
||||
assert infer_format_from_uri("/home/user/\u00e9t\u00e9.csv") == "csv"
|
||||
|
||||
def test_unknown_extension(self):
|
||||
assert infer_format_from_uri("/home/user/readme.txt") is None
|
||||
|
||||
def test_no_extension(self):
|
||||
assert infer_format_from_uri("workspace://abc123") is None
|
||||
|
||||
# --- MIME-based ---
|
||||
|
||||
def test_mime_json(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/json") == "json"
|
||||
|
||||
def test_mime_csv(self):
|
||||
assert infer_format_from_uri("workspace://abc123#text/csv") == "csv"
|
||||
|
||||
def test_mime_tsv(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#text/tab-separated-values")
|
||||
== "tsv"
|
||||
)
|
||||
|
||||
def test_mime_ndjson(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#application/x-ndjson") == "jsonl"
|
||||
)
|
||||
|
||||
def test_mime_yaml(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/x-yaml") == "yaml"
|
||||
|
||||
def test_mime_xlsx(self):
|
||||
uri = "workspace://abc123#application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
assert infer_format_from_uri(uri) == "xlsx"
|
||||
|
||||
def test_mime_parquet(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#application/vnd.apache.parquet")
|
||||
== "parquet"
|
||||
)
|
||||
|
||||
def test_unknown_mime(self):
|
||||
assert infer_format_from_uri("workspace://abc123#text/plain") is None
|
||||
|
||||
def test_unknown_mime_falls_through_to_extension(self):
|
||||
# Unknown MIME (text/plain) should fall through to extension-based detection.
|
||||
assert infer_format_from_uri("workspace:///data.csv#text/plain") == "csv"
|
||||
|
||||
# --- MIME takes precedence over extension ---
|
||||
|
||||
def test_mime_overrides_extension(self):
|
||||
# .txt extension but JSON MIME → json
|
||||
assert infer_format_from_uri("workspace:///file.txt#application/json") == "json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseJson:
|
||||
def test_array(self):
|
||||
result = parse_file_content("[1, 2, 3]", "json")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
def test_object(self):
|
||||
result = parse_file_content('{"key": "value"}', "json")
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_nested(self):
|
||||
content = json.dumps({"rows": [[1, 2], [3, 4]]})
|
||||
result = parse_file_content(content, "json")
|
||||
assert result == {"rows": [[1, 2], [3, 4]]}
|
||||
|
||||
def test_scalar_string_stays_as_string(self):
|
||||
result = parse_file_content('"hello"', "json")
|
||||
assert result == '"hello"' # original content, not parsed
|
||||
|
||||
def test_scalar_number_stays_as_string(self):
|
||||
result = parse_file_content("42", "json")
|
||||
assert result == "42"
|
||||
|
||||
def test_scalar_boolean_stays_as_string(self):
|
||||
result = parse_file_content("true", "json")
|
||||
assert result == "true"
|
||||
|
||||
def test_null_stays_as_string(self):
|
||||
result = parse_file_content("null", "json")
|
||||
assert result == "null"
|
||||
|
||||
def test_invalid_json_fallback(self):
|
||||
content = "not json at all"
|
||||
result = parse_file_content(content, "json")
|
||||
assert result == content
|
||||
|
||||
def test_empty_string_fallback(self):
|
||||
result = parse_file_content("", "json")
|
||||
assert result == ""
|
||||
|
||||
def test_bytes_input_decoded(self):
|
||||
result = parse_file_content(b"[1, 2, 3]", "json")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — JSONL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseJsonl:
|
||||
def test_tabular_uniform_dicts_to_table_format(self):
|
||||
"""JSONL with uniform dict keys → table format (header + rows),
|
||||
consistent with CSV/TSV/Parquet/Excel output."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":"yellow"}\n{"name":"cherry","color":"red"}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
["name", "color"],
|
||||
["apple", "red"],
|
||||
["banana", "yellow"],
|
||||
["cherry", "red"],
|
||||
]
|
||||
|
||||
def test_tabular_single_key_dicts(self):
|
||||
"""JSONL with single-key uniform dicts → table format."""
|
||||
content = '{"a": 1}\n{"a": 2}\n{"a": 3}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [["a"], [1], [2], [3]]
|
||||
|
||||
def test_tabular_blank_lines_skipped(self):
|
||||
content = '{"a": 1}\n\n{"a": 2}\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [["a"], [1], [2]]
|
||||
|
||||
def test_heterogeneous_dicts_stay_as_list(self):
|
||||
"""JSONL with different keys across objects → list of dicts (no table)."""
|
||||
content = '{"name":"apple"}\n{"color":"red"}\n{"size":3}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"name": "apple"}, {"color": "red"}, {"size": 3}]
|
||||
|
||||
def test_partially_overlapping_keys_stay_as_list(self):
|
||||
"""JSONL dicts with partially overlapping keys → list of dicts."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","size":"medium"}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
{"name": "apple", "color": "red"},
|
||||
{"name": "banana", "size": "medium"},
|
||||
]
|
||||
|
||||
def test_mixed_types_stay_as_list(self):
|
||||
"""JSONL with non-dict lines → list of parsed values (no table)."""
|
||||
content = '1\n"hello"\n[1,2]\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [1, "hello", [1, 2]]
|
||||
|
||||
def test_mixed_dicts_and_non_dicts_stay_as_list(self):
|
||||
"""JSONL mixing dicts and non-dicts → list of parsed values."""
|
||||
content = '{"a": 1}\n42\n{"b": 2}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"a": 1}, 42, {"b": 2}]
|
||||
|
||||
def test_tabular_preserves_key_order(self):
|
||||
"""Table header should follow the key order of the first object."""
|
||||
content = '{"z": 1, "a": 2}\n{"z": 3, "a": 4}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result[0] == ["z", "a"] # order from first object
|
||||
assert result[1] == [1, 2]
|
||||
assert result[2] == [3, 4]
|
||||
|
||||
def test_single_dict_stays_as_list(self):
|
||||
"""Single-line JSONL with one dict → [dict], NOT a table.
|
||||
Tabular detection requires ≥2 dicts to avoid vacuously true all()."""
|
||||
content = '{"a": 1, "b": 2}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"a": 1, "b": 2}]
|
||||
|
||||
def test_tabular_with_none_values(self):
|
||||
"""Uniform keys but some null values → table with None cells."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":null}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
["name", "color"],
|
||||
["apple", "red"],
|
||||
["banana", None],
|
||||
]
|
||||
|
||||
def test_empty_file_fallback(self):
|
||||
result = parse_file_content("", "jsonl")
|
||||
assert result == ""
|
||||
|
||||
def test_all_blank_lines_fallback(self):
|
||||
result = parse_file_content("\n\n\n", "jsonl")
|
||||
assert result == "\n\n\n"
|
||||
|
||||
def test_invalid_line_fallback(self):
|
||||
content = '{"a": 1}\nnot json\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == content # fallback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — CSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseCsv:
|
||||
def test_basic(self):
|
||||
content = "Name,Score\nAlice,90\nBob,85"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_quoted_fields(self):
|
||||
content = 'Name,Bio\nAlice,"Loves, commas"\nBob,Simple'
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result[1] == ["Alice", "Loves, commas"]
|
||||
|
||||
def test_single_column_fallback(self):
|
||||
# Only 1 column — not tabular enough.
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == content
|
||||
|
||||
def test_empty_rows_skipped(self):
|
||||
content = "A,B\n\n1,2\n\n3,4"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["A", "B"], ["1", "2"], ["3", "4"]]
|
||||
|
||||
def test_empty_file_fallback(self):
|
||||
result = parse_file_content("", "csv")
|
||||
assert result == ""
|
||||
|
||||
def test_utf8_bom(self):
|
||||
"""CSV with a UTF-8 BOM should parse correctly (BOM stripped by decode)."""
|
||||
bom = "\ufeff"
|
||||
content = bom + "Name,Score\nAlice,90\nBob,85"
|
||||
result = parse_file_content(content, "csv")
|
||||
# The BOM may be part of the first header cell; ensure rows are still parsed.
|
||||
assert len(result) == 3
|
||||
assert result[1] == ["Alice", "90"]
|
||||
assert result[2] == ["Bob", "85"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — TSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseTsv:
|
||||
def test_basic(self):
|
||||
content = "Name\tScore\nAlice\t90\nBob\t85"
|
||||
result = parse_file_content(content, "tsv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_single_column_fallback(self):
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "tsv")
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — YAML
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseYaml:
|
||||
def test_list(self):
|
||||
content = "- apple\n- banana\n- cherry"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == ["apple", "banana", "cherry"]
|
||||
|
||||
def test_dict(self):
|
||||
content = "name: Alice\nage: 30"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == {"name": "Alice", "age": 30}
|
||||
|
||||
def test_nested(self):
|
||||
content = "users:\n - name: Alice\n - name: Bob"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == {"users": [{"name": "Alice"}, {"name": "Bob"}]}
|
||||
|
||||
def test_scalar_stays_as_string(self):
|
||||
result = parse_file_content("hello world", "yaml")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_invalid_yaml_fallback(self):
|
||||
content = ":\n :\n invalid: - -"
|
||||
result = parse_file_content(content, "yaml")
|
||||
# Malformed YAML should fall back to the original string, not raise.
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — TOML
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseToml:
|
||||
def test_basic(self):
|
||||
content = '[server]\nhost = "localhost"\nport = 8080'
|
||||
result = parse_file_content(content, "toml")
|
||||
assert result == {"server": {"host": "localhost", "port": 8080}}
|
||||
|
||||
def test_flat(self):
|
||||
content = 'name = "test"\ncount = 42'
|
||||
result = parse_file_content(content, "toml")
|
||||
assert result == {"name": "test", "count": 42}
|
||||
|
||||
def test_empty_string_returns_empty_dict(self):
|
||||
result = parse_file_content("", "toml")
|
||||
assert result == {}
|
||||
|
||||
def test_invalid_toml_fallback(self):
|
||||
result = parse_file_content("not = [valid toml", "toml")
|
||||
assert result == "not = [valid toml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — Parquet (binary)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
try:
|
||||
import pyarrow as _pa # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
|
||||
_has_pyarrow = True
|
||||
except ImportError:
|
||||
_has_pyarrow = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
|
||||
class TestParseParquet:
|
||||
@pytest.fixture
|
||||
def parquet_bytes(self) -> bytes:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
return buf.getvalue()
|
||||
|
||||
def test_basic(self, parquet_bytes: bytes):
|
||||
result = parse_file_content(parquet_bytes, "parquet")
|
||||
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
|
||||
|
||||
def test_string_input_fallback(self):
|
||||
# Parquet is binary — string input can't be parsed.
|
||||
result = parse_file_content("not parquet", "parquet")
|
||||
assert result == "not parquet"
|
||||
|
||||
def test_invalid_bytes_fallback(self):
|
||||
result = parse_file_content(b"not parquet bytes", "parquet")
|
||||
assert result == b"not parquet bytes"
|
||||
|
||||
def test_empty_bytes_fallback(self):
|
||||
"""Empty binary input should return the empty bytes, not crash."""
|
||||
result = parse_file_content(b"", "parquet")
|
||||
assert result == b""
|
||||
|
||||
def test_nan_replaced_with_none(self):
|
||||
"""NaN values in Parquet must become None for JSON serializability."""
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", None, "z"]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
result = parse_file_content(buf.getvalue(), "parquet")
|
||||
# Row with NaN in float col → None
|
||||
assert result[2][0] is None # float NaN → None
|
||||
assert result[2][1] is None # str None → None
|
||||
# Ensure no NaN leaks
|
||||
for row in result[1:]:
|
||||
for cell in row:
|
||||
if isinstance(cell, float):
|
||||
assert not math.isnan(cell), f"NaN leaked: {row}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — Excel (binary)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseExcel:
|
||||
@pytest.fixture
|
||||
def xlsx_bytes(self) -> bytes:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False) # type: ignore[arg-type] # BytesIO is a valid target
|
||||
return buf.getvalue()
|
||||
|
||||
def test_basic(self, xlsx_bytes: bytes):
|
||||
result = parse_file_content(xlsx_bytes, "xlsx")
|
||||
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
|
||||
|
||||
def test_string_input_fallback(self):
|
||||
result = parse_file_content("not xlsx", "xlsx")
|
||||
assert result == "not xlsx"
|
||||
|
||||
def test_invalid_bytes_fallback(self):
|
||||
result = parse_file_content(b"not xlsx bytes", "xlsx")
|
||||
assert result == b"not xlsx bytes"
|
||||
|
||||
def test_empty_bytes_fallback(self):
|
||||
"""Empty binary input should return the empty bytes, not crash."""
|
||||
result = parse_file_content(b"", "xlsx")
|
||||
assert result == b""
|
||||
|
||||
def test_nan_replaced_with_none(self):
|
||||
"""NaN values in float columns must become None for JSON serializability."""
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", "y", None]})
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False) # type: ignore[arg-type]
|
||||
result = parse_file_content(buf.getvalue(), "xlsx")
|
||||
# Row with NaN in float col → None, not float('nan')
|
||||
assert result[2][0] is None # float NaN → None
|
||||
assert result[3][1] is None # str None → None
|
||||
# Ensure no NaN leaks
|
||||
for row in result[1:]: # skip header
|
||||
for cell in row:
|
||||
if isinstance(cell, float):
|
||||
assert not math.isnan(cell), f"NaN leaked: {row}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — unknown format / fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFallback:
|
||||
def test_unknown_format_returns_content(self):
|
||||
result = parse_file_content("hello world", "xml")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_none_format_returns_content(self):
|
||||
# Shouldn't normally be called with unrecognised format, but must not crash.
|
||||
result = parse_file_content("hello", "unknown_format")
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BINARY_FORMATS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBinaryFormats:
|
||||
def test_parquet_is_binary(self):
|
||||
assert "parquet" in BINARY_FORMATS
|
||||
|
||||
def test_xlsx_is_binary(self):
|
||||
assert "xlsx" in BINARY_FORMATS
|
||||
|
||||
def test_text_formats_not_binary(self):
|
||||
for fmt in ("json", "jsonl", "csv", "tsv", "yaml", "toml"):
|
||||
assert fmt not in BINARY_FORMATS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MIME mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMimeMapping:
|
||||
def test_application_yaml(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/yaml") == "yaml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSV sniffer fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCsvSnifferFallback:
|
||||
def test_tab_delimited_with_csv_format(self):
|
||||
"""Tab-delimited content parsed as csv should use sniffer fallback."""
|
||||
content = "Name\tScore\nAlice\t90\nBob\t85"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_sniffer_failure_returns_content(self):
|
||||
"""When sniffer fails, single-column falls back to raw content."""
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenpyxlInvalidFile fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenpyxlFallback:
|
||||
def test_invalid_xlsx_non_strict(self):
|
||||
"""Invalid xlsx bytes should fall back gracefully in non-strict mode."""
|
||||
result = parse_file_content(b"not xlsx bytes", "xlsx")
|
||||
assert result == b"not xlsx bytes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Header-only CSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeaderOnlyCsv:
|
||||
def test_header_only_csv_returns_header_row(self):
|
||||
"""CSV with only a header row (no data rows) should return [[header]]."""
|
||||
content = "Name,Score"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"]]
|
||||
|
||||
def test_header_only_csv_with_trailing_newline(self):
|
||||
content = "Name,Score\n"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary format + line range (line range ignored for binary formats)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
|
||||
class TestBinaryFormatLineRange:
|
||||
def test_parquet_ignores_line_range(self):
|
||||
"""Binary formats should parse the full file regardless of line range.
|
||||
|
||||
Line ranges are meaningless for binary formats (parquet/xlsx) — the
|
||||
caller (file_ref._expand_bare_ref) passes raw bytes and the parser
|
||||
should return the complete structured data.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
# parse_file_content itself doesn't take a line range — this tests
|
||||
# that the full content is parsed even though the bytes could have
|
||||
# been truncated upstream (it's not, by design).
|
||||
result = parse_file_content(buf.getvalue(), "parquet")
|
||||
assert result == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy .xls UX
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestXlsFallback:
|
||||
def test_xls_returns_helpful_error_string(self):
|
||||
"""Uploading a .xls file should produce a helpful error, not garbled binary."""
|
||||
result = parse_file_content(b"\xd0\xcf\x11\xe0garbled", "xls")
|
||||
assert isinstance(result, str)
|
||||
assert ".xlsx" in result
|
||||
assert "not supported" in result.lower()
|
||||
|
||||
def test_xls_with_string_content(self):
|
||||
result = parse_file_content("some text", "xls")
|
||||
assert isinstance(result, str)
|
||||
assert ".xlsx" in result
|
||||
@@ -8,12 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.file import (
|
||||
is_media_file_ref,
|
||||
parse_data_uri,
|
||||
resolve_media_content,
|
||||
store_media_file,
|
||||
)
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
@@ -349,162 +344,3 @@ class TestFileCloudIntegration:
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_media_file_ref
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsMediaFileRef:
|
||||
def test_data_uri(self):
|
||||
assert is_media_file_ref("data:image/png;base64,iVBORw0KGg==") is True
|
||||
|
||||
def test_workspace_uri(self):
|
||||
assert is_media_file_ref("workspace://abc123") is True
|
||||
|
||||
def test_workspace_uri_with_mime(self):
|
||||
assert is_media_file_ref("workspace://abc123#image/png") is True
|
||||
|
||||
def test_http_url(self):
|
||||
assert is_media_file_ref("http://example.com/image.png") is True
|
||||
|
||||
def test_https_url(self):
|
||||
assert is_media_file_ref("https://example.com/image.png") is True
|
||||
|
||||
def test_plain_text(self):
|
||||
assert is_media_file_ref("print('hello')") is False
|
||||
|
||||
def test_local_path(self):
|
||||
assert is_media_file_ref("/tmp/file.txt") is False
|
||||
|
||||
def test_empty_string(self):
|
||||
assert is_media_file_ref("") is False
|
||||
|
||||
def test_filename(self):
|
||||
assert is_media_file_ref("image.png") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_data_uri
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseDataUri:
|
||||
def test_valid_png(self):
|
||||
result = parse_data_uri("data:image/png;base64,iVBORw0KGg==")
|
||||
assert result is not None
|
||||
mime, payload = result
|
||||
assert mime == "image/png"
|
||||
assert payload == "iVBORw0KGg=="
|
||||
|
||||
def test_valid_text(self):
|
||||
result = parse_data_uri("data:text/plain;base64,SGVsbG8=")
|
||||
assert result is not None
|
||||
assert result[0] == "text/plain"
|
||||
assert result[1] == "SGVsbG8="
|
||||
|
||||
def test_mime_case_normalized(self):
|
||||
result = parse_data_uri("data:IMAGE/PNG;base64,abc")
|
||||
assert result is not None
|
||||
assert result[0] == "image/png"
|
||||
|
||||
def test_not_data_uri(self):
|
||||
assert parse_data_uri("workspace://abc123") is None
|
||||
|
||||
def test_plain_text(self):
|
||||
assert parse_data_uri("hello world") is None
|
||||
|
||||
def test_missing_base64(self):
|
||||
assert parse_data_uri("data:image/png;utf-8,abc") is None
|
||||
|
||||
def test_empty_payload(self):
|
||||
result = parse_data_uri("data:image/png;base64,")
|
||||
assert result is not None
|
||||
assert result[1] == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_media_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveMediaContent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_text_passthrough(self):
|
||||
"""Plain text content (not a media ref) passes through unchanged."""
|
||||
ctx = make_test_context()
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("print('hello')"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == "print('hello')"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_string_passthrough(self):
|
||||
"""Empty string passes through unchanged."""
|
||||
ctx = make_test_context()
|
||||
result = await resolve_media_content(
|
||||
MediaFileType(""),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_media_ref_delegates_to_store(self):
|
||||
"""Media references are resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("workspace://img123"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == "data:image/png;base64,abc"
|
||||
mock_store.assert_called_once_with(
|
||||
MediaFileType("workspace://img123"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_uri_delegates_to_store(self):
|
||||
"""Data URIs are also resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
data_uri = "data:image/png;base64,iVBORw0KGg=="
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType(data_uri)),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType(data_uri),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == data_uri
|
||||
mock_store.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_https_url_delegates_to_store(self):
|
||||
"""HTTPS URLs are resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("https://example.com/image.png"),
|
||||
ctx,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
assert result == "data:image/png;base64,abc"
|
||||
mock_store.assert_called_once_with(
|
||||
MediaFileType("https://example.com/image.png"),
|
||||
ctx,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration
|
||||
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||
|
||||
from backend.util import feature_flag
|
||||
from backend.util.settings import BehaveAs, Settings
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -21,95 +21,6 @@ class DiscordChannel(str, Enum):
|
||||
PRODUCT = "product" # For product alerts (low balance, zero balance, etc.)
|
||||
|
||||
|
||||
def _before_send(event, hint):
|
||||
"""Filter out expected/transient errors from Sentry to reduce noise."""
|
||||
if "exc_info" in hint:
|
||||
exc_type, exc_value, _ = hint["exc_info"]
|
||||
exc_msg = str(exc_value).lower() if exc_value else ""
|
||||
|
||||
# AMQP/RabbitMQ transient connection errors — expected during deploys
|
||||
amqp_keywords = [
|
||||
"amqpconnection",
|
||||
"amqpconnector",
|
||||
"connection_forced",
|
||||
"channelinvalidstateerror",
|
||||
"no active transport",
|
||||
]
|
||||
if any(kw in exc_msg for kw in amqp_keywords):
|
||||
return None
|
||||
|
||||
# "connection refused" only for AMQP-related exceptions (not other services)
|
||||
if "connection refused" in exc_msg:
|
||||
exc_module = getattr(exc_type, "__module__", "") or ""
|
||||
exc_name = getattr(exc_type, "__name__", "") or ""
|
||||
amqp_indicators = ["aio_pika", "aiormq", "amqp", "pika", "rabbitmq"]
|
||||
if any(
|
||||
ind in exc_module.lower() or ind in exc_name.lower()
|
||||
for ind in amqp_indicators
|
||||
) or any(kw in exc_msg for kw in ["amqp", "pika", "rabbitmq"]):
|
||||
return None
|
||||
|
||||
# User-caused credential/auth errors — not platform bugs
|
||||
user_auth_keywords = [
|
||||
"incorrect api key",
|
||||
"invalid x-api-key",
|
||||
"missing authentication header",
|
||||
"invalid api token",
|
||||
"authentication_error",
|
||||
]
|
||||
if any(kw in exc_msg for kw in user_auth_keywords):
|
||||
return None
|
||||
|
||||
# Expected business logic — insufficient balance
|
||||
if "insufficient balance" in exc_msg or "no credits left" in exc_msg:
|
||||
return None
|
||||
|
||||
# Expected security check — blocked IP access
|
||||
if "access to blocked or private ip" in exc_msg:
|
||||
return None
|
||||
|
||||
# Discord bot token misconfiguration — not a platform error
|
||||
if "improper token has been passed" in exc_msg or (
|
||||
exc_type and exc_type.__name__ == "Forbidden" and "50001" in exc_msg
|
||||
):
|
||||
return None
|
||||
|
||||
# Google metadata DNS errors — expected in non-GCP environments
|
||||
if (
|
||||
"metadata.google.internal" in exc_msg
|
||||
and settings.config.behave_as != BehaveAs.CLOUD
|
||||
):
|
||||
return None
|
||||
|
||||
# Inactive email recipients — expected for bounced addresses
|
||||
if "marked as inactive" in exc_msg or "inactive addresses" in exc_msg:
|
||||
return None
|
||||
|
||||
# Also filter log-based events for known noisy messages.
|
||||
# Sentry's LoggingIntegration stores log messages under "logentry", not "message".
|
||||
logentry = event.get("logentry") or {}
|
||||
log_msg = (
|
||||
logentry.get("formatted") or logentry.get("message") or event.get("message")
|
||||
)
|
||||
if event.get("logger") and log_msg:
|
||||
msg = log_msg.lower()
|
||||
noisy_patterns = [
|
||||
"amqpconnection",
|
||||
"connection_forced",
|
||||
"unclosed client session",
|
||||
"unclosed connector",
|
||||
]
|
||||
if any(p in msg for p in noisy_patterns):
|
||||
return None
|
||||
# "connection refused" in logs only when AMQP-related context is present
|
||||
if "connection refused" in msg and any(
|
||||
ind in msg for ind in ("amqp", "pika", "rabbitmq", "aio_pika", "aiormq")
|
||||
):
|
||||
return None
|
||||
|
||||
return event
|
||||
|
||||
|
||||
def sentry_init():
|
||||
sentry_dsn = settings.secrets.sentry_dsn
|
||||
integrations = []
|
||||
@@ -124,7 +35,6 @@ def sentry_init():
|
||||
profiles_sample_rate=1.0,
|
||||
environment=f"app:{settings.config.app_env.value}-behave:{settings.config.behave_as.value}",
|
||||
_experiments={"enable_logs": True},
|
||||
before_send=_before_send,
|
||||
integrations=[
|
||||
AsyncioIntegration(),
|
||||
LoggingIntegration(sentry_logs_level=logging.INFO),
|
||||
|
||||
@@ -64,7 +64,7 @@ def send_rate_limited_discord_alert(
|
||||
return True
|
||||
|
||||
except Exception as alert_error:
|
||||
logger.warning(f"Failed to send Discord alert: {alert_error}")
|
||||
logger.error(f"Failed to send Discord alert: {alert_error}")
|
||||
return False
|
||||
|
||||
|
||||
@@ -182,8 +182,7 @@ def conn_retry(
|
||||
func_name = getattr(retry_state.fn, "__name__", "unknown")
|
||||
|
||||
if retry_state.outcome.failed and retry_state.next_action is None:
|
||||
# Final failure is logged by sync_wrapper/async_wrapper — skip here to avoid duplicates
|
||||
pass
|
||||
logger.error(f"{prefix} {action_name} failed after retries: {exception}")
|
||||
else:
|
||||
if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
|
||||
if send_rate_limited_discord_alert(
|
||||
@@ -226,7 +225,7 @@ def conn_retry(
|
||||
logger.info(f"{prefix} {action_name} completed successfully.")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"{prefix} {action_name} failed after retries: {e}")
|
||||
logger.error(f"{prefix} {action_name} failed after retries: {e}")
|
||||
raise
|
||||
|
||||
@wraps(func)
|
||||
@@ -238,7 +237,7 @@ def conn_retry(
|
||||
logger.info(f"{prefix} {action_name} completed successfully.")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"{prefix} {action_name} failed after retries: {e}")
|
||||
logger.error(f"{prefix} {action_name} failed after retries: {e}")
|
||||
raise
|
||||
|
||||
return async_wrapper if is_coroutine else sync_wrapper
|
||||
|
||||
@@ -89,10 +89,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
le=500,
|
||||
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
|
||||
)
|
||||
tally_extraction_llm_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
description="OpenRouter model ID used for extracting business understanding from Tally form data",
|
||||
)
|
||||
ollama_host: str = Field(
|
||||
default="localhost:11434",
|
||||
description="Default Ollama host; exempted from SSRF checks.",
|
||||
@@ -121,10 +117,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=True,
|
||||
description="If authentication is enabled or not",
|
||||
)
|
||||
enable_invite_gate: bool = Field(
|
||||
default=True,
|
||||
description="If the invite-only signup gate is enforced",
|
||||
)
|
||||
enable_credit: bool = Field(
|
||||
default=False,
|
||||
description="If user credit system is enabled or not",
|
||||
|
||||
@@ -249,87 +249,6 @@ def convert(value: Any, target_type: Any) -> Any:
|
||||
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
|
||||
|
||||
|
||||
def _value_satisfies_type(value: Any, target: Any) -> bool:
|
||||
"""Check whether *value* already satisfies *target*, including inner elements.
|
||||
|
||||
For union types this checks each member; for generic container types it
|
||||
recursively checks that inner elements satisfy the type args (e.g. every
|
||||
element in a ``list[str]`` is a ``str``). Returns ``False`` when uncertain
|
||||
so the caller falls through to :func:`convert`.
|
||||
"""
|
||||
# typing.Any cannot be used with isinstance(); treat as always satisfied.
|
||||
if target is Any:
|
||||
return True
|
||||
|
||||
origin = get_origin(target)
|
||||
|
||||
if origin is Union or origin is types.UnionType:
|
||||
non_none = [a for a in get_args(target) if a is not type(None)]
|
||||
return any(_value_satisfies_type(value, member) for member in non_none)
|
||||
|
||||
# Generic container type (e.g. list[str], dict[str, int])
|
||||
if origin is not None:
|
||||
# Guard: origin may not be a runtime type (e.g. Literal)
|
||||
if not isinstance(origin, type):
|
||||
return False
|
||||
if not isinstance(value, origin):
|
||||
return False
|
||||
args = get_args(target)
|
||||
if not args:
|
||||
return True
|
||||
# Check inner elements satisfy the type args
|
||||
if _is_type_or_subclass(origin, list):
|
||||
return all(_value_satisfies_type(v, args[0]) for v in value)
|
||||
if _is_type_or_subclass(origin, dict) and len(args) >= 2:
|
||||
return all(
|
||||
_value_satisfies_type(k, args[0]) and _value_satisfies_type(v, args[1])
|
||||
for k, v in value.items()
|
||||
)
|
||||
if (
|
||||
_is_type_or_subclass(origin, set) or _is_type_or_subclass(origin, frozenset)
|
||||
) and args:
|
||||
return all(_value_satisfies_type(v, args[0]) for v in value)
|
||||
if _is_type_or_subclass(origin, tuple):
|
||||
# Homogeneous tuple[T, ...] — single type + Ellipsis
|
||||
if len(args) == 2 and args[1] is Ellipsis:
|
||||
return all(_value_satisfies_type(v, args[0]) for v in value)
|
||||
# Heterogeneous tuple[T1, T2, ...] — positional types
|
||||
if len(value) != len(args):
|
||||
return False
|
||||
return all(_value_satisfies_type(v, t) for v, t in zip(value, args))
|
||||
# Unhandled generic origin — fall through to convert()
|
||||
return False
|
||||
|
||||
# Simple type (e.g. str, int)
|
||||
if isinstance(target, type):
|
||||
return isinstance(value, target)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def coerce_inputs_to_schema(data: dict[str, Any], schema: type) -> None:
|
||||
"""Coerce *data* values in-place to match *schema*'s field types.
|
||||
|
||||
Uses ``model_fields`` (not ``__annotations__``) so inherited fields are
|
||||
included. Skips coercion when the value already satisfies the target
|
||||
type — in particular for union-typed fields where the value matches one
|
||||
member but differs from the annotation object itself.
|
||||
|
||||
This is the single authoritative coercion step shared by the executor
|
||||
(``validate_exec``) and the CoPilot (``execute_block``).
|
||||
"""
|
||||
for name, field_info in schema.model_fields.items():
|
||||
value = data.get(name)
|
||||
if value is None:
|
||||
continue
|
||||
target = field_info.annotation
|
||||
if target is None:
|
||||
continue
|
||||
if _value_satisfies_type(value, target):
|
||||
continue
|
||||
data[name] = convert(value, target)
|
||||
|
||||
|
||||
class FormattedStringType(str):
|
||||
string_format: str
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Any, List, Literal, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.type import _value_satisfies_type, coerce_inputs_to_schema, convert
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
def test_type_conversion():
|
||||
@@ -48,343 +46,3 @@ def test_type_conversion():
|
||||
# Test other empty list conversions
|
||||
assert convert([], int) == 0 # len([]) = 0
|
||||
assert convert([], Optional[int]) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _value_satisfies_type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValueSatisfiesType:
|
||||
# --- simple types ---
|
||||
def test_simple_match(self):
|
||||
assert _value_satisfies_type("hello", str) is True
|
||||
assert _value_satisfies_type(42, int) is True
|
||||
assert _value_satisfies_type(3.14, float) is True
|
||||
assert _value_satisfies_type(True, bool) is True
|
||||
|
||||
def test_simple_mismatch(self):
|
||||
assert _value_satisfies_type("hello", int) is False
|
||||
assert _value_satisfies_type(42, str) is False
|
||||
assert _value_satisfies_type([1, 2], str) is False
|
||||
|
||||
# --- Any ---
|
||||
def test_any_always_satisfied(self):
|
||||
assert _value_satisfies_type("hello", Any) is True
|
||||
assert _value_satisfies_type(42, Any) is True
|
||||
assert _value_satisfies_type([1, 2], Any) is True
|
||||
assert _value_satisfies_type(None, Any) is True
|
||||
|
||||
# --- Optional / Union ---
|
||||
def test_optional_with_value(self):
|
||||
assert _value_satisfies_type("hello", Optional[str]) is True
|
||||
assert _value_satisfies_type(42, Optional[int]) is True
|
||||
|
||||
def test_optional_mismatch(self):
|
||||
assert _value_satisfies_type(42, Optional[str]) is False
|
||||
|
||||
def test_union_matches_first_member(self):
|
||||
assert _value_satisfies_type("hello", str | list[str]) is True
|
||||
|
||||
def test_union_matches_second_member(self):
|
||||
assert _value_satisfies_type(["a", "b"], str | list[str]) is True
|
||||
|
||||
def test_union_no_match(self):
|
||||
assert _value_satisfies_type(42, str | list[str]) is False
|
||||
|
||||
# --- list[T] ---
|
||||
def test_list_str_all_match(self):
|
||||
assert _value_satisfies_type(["a", "b", "c"], list[str]) is True
|
||||
|
||||
def test_list_str_inner_mismatch(self):
|
||||
assert _value_satisfies_type([1, 2, 3], list[str]) is False
|
||||
|
||||
def test_list_int_all_match(self):
|
||||
assert _value_satisfies_type([1, 2, 3], list[int]) is True
|
||||
|
||||
def test_list_int_inner_mismatch(self):
|
||||
assert _value_satisfies_type(["1", "2"], list[int]) is False
|
||||
|
||||
def test_empty_list_satisfies_any_list_type(self):
|
||||
assert _value_satisfies_type([], list[str]) is True
|
||||
assert _value_satisfies_type([], list[int]) is True
|
||||
|
||||
def test_string_does_not_satisfy_list(self):
|
||||
assert _value_satisfies_type("hello", list[str]) is False
|
||||
|
||||
# --- nested list[list[str]] ---
|
||||
def test_nested_list_all_match(self):
|
||||
assert _value_satisfies_type([["a", "b"], ["c"]], list[list[str]]) is True
|
||||
|
||||
def test_nested_list_inner_mismatch(self):
|
||||
assert _value_satisfies_type([["a", 1], ["c"]], list[list[str]]) is False
|
||||
|
||||
def test_nested_list_outer_mismatch(self):
|
||||
assert _value_satisfies_type(["a", "b"], list[list[str]]) is False
|
||||
|
||||
# --- dict[K, V] ---
|
||||
def test_dict_str_int_match(self):
|
||||
assert _value_satisfies_type({"a": 1, "b": 2}, dict[str, int]) is True
|
||||
|
||||
def test_dict_str_int_value_mismatch(self):
|
||||
assert _value_satisfies_type({"a": "1", "b": "2"}, dict[str, int]) is False
|
||||
|
||||
def test_dict_str_int_key_mismatch(self):
|
||||
assert _value_satisfies_type({1: 1, 2: 2}, dict[str, int]) is False
|
||||
|
||||
def test_empty_dict_satisfies(self):
|
||||
assert _value_satisfies_type({}, dict[str, int]) is True
|
||||
|
||||
# --- set[T] / tuple[T] ---
|
||||
def test_set_match(self):
|
||||
assert _value_satisfies_type({1, 2, 3}, set[int]) is True
|
||||
|
||||
def test_set_mismatch(self):
|
||||
assert _value_satisfies_type({"a", "b"}, set[int]) is False
|
||||
|
||||
def test_tuple_homogeneous_match(self):
|
||||
assert _value_satisfies_type((1, 2, 3), tuple[int, ...]) is True
|
||||
|
||||
def test_tuple_homogeneous_mismatch(self):
|
||||
assert _value_satisfies_type((1, "2", 3), tuple[int, ...]) is False
|
||||
|
||||
def test_tuple_heterogeneous_match(self):
|
||||
assert _value_satisfies_type(("a", 1, True), tuple[str, int, bool]) is True
|
||||
|
||||
def test_tuple_heterogeneous_mismatch(self):
|
||||
assert _value_satisfies_type(("a", "b", True), tuple[str, int, bool]) is False
|
||||
|
||||
def test_tuple_heterogeneous_wrong_length(self):
|
||||
assert _value_satisfies_type(("a", 1), tuple[str, int, bool]) is False
|
||||
|
||||
# --- bare generics (no args) ---
|
||||
def test_bare_list(self):
|
||||
assert _value_satisfies_type([1, "a"], list) is True
|
||||
|
||||
def test_bare_dict(self):
|
||||
assert _value_satisfies_type({"a": 1}, dict) is True
|
||||
|
||||
# --- union with generic inner mismatch ---
|
||||
def test_union_list_with_wrong_inner_falls_through(self):
|
||||
# [1, 2] doesn't satisfy list[str] (inner mismatch), and not str either
|
||||
assert _value_satisfies_type([1, 2], str | list[str]) is False
|
||||
|
||||
# --- Literal (non-runtime origin) ---
|
||||
def test_literal_does_not_crash(self):
|
||||
"""Literal origins are not runtime types — should return False, not crash."""
|
||||
assert _value_satisfies_type("active", Literal["active", "inactive"]) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# coerce_inputs_to_schema — using real Pydantic models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SampleSchema(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
items: list[str]
|
||||
config: dict[str, int] = {}
|
||||
|
||||
|
||||
class NestedSchema(BaseModel):
|
||||
rows: list[list[str]]
|
||||
|
||||
|
||||
class UnionSchema(BaseModel):
|
||||
content: str | list[str]
|
||||
|
||||
|
||||
class OptionalSchema(BaseModel):
|
||||
label: Optional[str] = None
|
||||
value: int = 0
|
||||
|
||||
|
||||
class AnyFieldSchema(BaseModel):
|
||||
data: Any
|
||||
|
||||
|
||||
class TestCoerceInputsToSchema:
|
||||
def test_string_to_int(self):
|
||||
data: dict[str, Any] = {"name": "test", "count": "42", "items": ["a"]}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["count"] == 42
|
||||
assert isinstance(data["count"], int)
|
||||
|
||||
def test_json_string_to_list(self):
|
||||
data: dict[str, Any] = {"name": "test", "count": 1, "items": '["a","b","c"]'}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["items"] == ["a", "b", "c"]
|
||||
|
||||
def test_already_correct_types_unchanged(self):
|
||||
data: dict[str, Any] = {
|
||||
"name": "test",
|
||||
"count": 42,
|
||||
"items": ["a", "b"],
|
||||
"config": {"x": 1},
|
||||
}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data == {
|
||||
"name": "test",
|
||||
"count": 42,
|
||||
"items": ["a", "b"],
|
||||
"config": {"x": 1},
|
||||
}
|
||||
|
||||
def test_inner_element_coercion(self):
|
||||
"""list[str] with int inner elements → coerced to strings."""
|
||||
data: dict[str, Any] = {"name": "test", "count": 1, "items": [1, 2, 3]}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["items"] == ["1", "2", "3"]
|
||||
|
||||
def test_dict_value_coercion(self):
|
||||
"""dict[str, int] with string values → coerced to ints."""
|
||||
data: dict[str, Any] = {
|
||||
"name": "test",
|
||||
"count": 1,
|
||||
"items": [],
|
||||
"config": {"x": "10", "y": "20"},
|
||||
}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["config"] == {"x": 10, "y": 20}
|
||||
|
||||
def test_nested_list_from_json_string(self):
|
||||
data: dict[str, Any] = {
|
||||
"rows": '[["Name","Score"],["Alice","90"]]',
|
||||
}
|
||||
coerce_inputs_to_schema(data, NestedSchema)
|
||||
assert data["rows"] == [["Name", "Score"], ["Alice", "90"]]
|
||||
|
||||
def test_nested_list_already_correct(self):
|
||||
original = [["a", "b"], ["c", "d"]]
|
||||
data: dict[str, Any] = {"rows": original}
|
||||
coerce_inputs_to_schema(data, NestedSchema)
|
||||
assert data["rows"] == original
|
||||
|
||||
def test_union_preserves_valid_list(self):
|
||||
"""list[str] value should NOT be stringified for str | list[str]."""
|
||||
data: dict[str, Any] = {"content": ["a", "b"]}
|
||||
coerce_inputs_to_schema(data, UnionSchema)
|
||||
assert data["content"] == ["a", "b"]
|
||||
assert isinstance(data["content"], list)
|
||||
|
||||
def test_union_preserves_valid_string(self):
|
||||
data: dict[str, Any] = {"content": "hello"}
|
||||
coerce_inputs_to_schema(data, UnionSchema)
|
||||
assert data["content"] == "hello"
|
||||
|
||||
def test_union_list_with_wrong_inner_gets_coerced(self):
|
||||
"""[1, 2] for str | list[str] — inner ints don't match list[str],
|
||||
so convert() is called. convert tries str first → stringifies."""
|
||||
data: dict[str, Any] = {"content": [1, 2]}
|
||||
coerce_inputs_to_schema(data, UnionSchema)
|
||||
# convert([1,2], str | list[str]) tries str first → "[1, 2]"
|
||||
# This is convert()'s union behavior — str wins over list[str]
|
||||
assert isinstance(data["content"], (str, list))
|
||||
|
||||
def test_skips_none_values(self):
|
||||
data: dict[str, Any] = {"label": None, "value": "5"}
|
||||
coerce_inputs_to_schema(data, OptionalSchema)
|
||||
assert data["label"] is None
|
||||
assert data["value"] == 5
|
||||
|
||||
def test_skips_missing_fields(self):
|
||||
data: dict[str, Any] = {"value": "10"}
|
||||
coerce_inputs_to_schema(data, OptionalSchema)
|
||||
assert "label" not in data
|
||||
assert data["value"] == 10
|
||||
|
||||
def test_any_field_skipped(self):
|
||||
"""Fields typed as Any should pass through without coercion."""
|
||||
data: dict[str, Any] = {"data": [1, "mixed", {"nested": True}]}
|
||||
coerce_inputs_to_schema(data, AnyFieldSchema)
|
||||
assert data["data"] == [1, "mixed", {"nested": True}]
|
||||
|
||||
def test_preserves_all_convert_capabilities(self):
|
||||
"""Verify coerce_inputs_to_schema doesn't lose any convert() capability
|
||||
that existed before the _value_satisfies_type gate was added."""
|
||||
|
||||
class FullSchema(BaseModel):
|
||||
as_int: int
|
||||
as_float: float
|
||||
as_bool: bool
|
||||
as_str: str
|
||||
as_list: list[int]
|
||||
as_dict: dict[str, str]
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"as_int": "42",
|
||||
"as_float": "3.14",
|
||||
"as_bool": "True",
|
||||
"as_str": 123,
|
||||
"as_list": "[1,2,3]",
|
||||
"as_dict": '{"a": "b"}',
|
||||
}
|
||||
coerce_inputs_to_schema(data, FullSchema)
|
||||
assert data["as_int"] == 42
|
||||
assert data["as_float"] == 3.14
|
||||
assert data["as_bool"] is True
|
||||
assert data["as_str"] == "123"
|
||||
assert data["as_list"] == [1, 2, 3]
|
||||
assert data["as_dict"] == {"a": "b"}
|
||||
|
||||
def test_inherited_fields_are_coerced(self):
|
||||
"""model_fields includes inherited fields; __annotations__ does not.
|
||||
This verifies that fields from a parent schema are still coerced."""
|
||||
|
||||
class ParentSchema(BaseModel):
|
||||
base_count: int
|
||||
|
||||
class ChildSchema(ParentSchema):
|
||||
name: str
|
||||
|
||||
# base_count is inherited — __annotations__ wouldn't include it
|
||||
assert "base_count" not in ChildSchema.__annotations__
|
||||
assert "base_count" in ChildSchema.model_fields
|
||||
|
||||
data: dict[str, Any] = {"base_count": "42", "name": "test"}
|
||||
coerce_inputs_to_schema(data, ChildSchema)
|
||||
assert data["base_count"] == 42
|
||||
assert isinstance(data["base_count"], int)
|
||||
|
||||
def test_nested_pydantic_model_field(self):
|
||||
"""dict input for a Pydantic model-typed field passes through.
|
||||
convert() doesn't construct Pydantic models — Pydantic validation
|
||||
handles that downstream. This test documents the behavior."""
|
||||
|
||||
class InnerModel(BaseModel):
|
||||
x: int
|
||||
|
||||
class OuterModel(BaseModel):
|
||||
inner: InnerModel
|
||||
|
||||
data: dict[str, Any] = {"inner": {"x": 1}}
|
||||
coerce_inputs_to_schema(data, OuterModel)
|
||||
# dict stays as dict — convert() doesn't construct Pydantic models
|
||||
assert data["inner"] == {"x": 1}
|
||||
assert isinstance(data["inner"], dict)
|
||||
|
||||
def test_literal_field_passes_through(self):
|
||||
"""Literal-typed fields should not crash coercion."""
|
||||
|
||||
class LiteralSchema(BaseModel):
|
||||
status: Literal["active", "inactive"]
|
||||
|
||||
data: dict[str, Any] = {"status": "active"}
|
||||
coerce_inputs_to_schema(data, LiteralSchema)
|
||||
assert data["status"] == "active"
|
||||
|
||||
def test_list_of_pydantic_model_field(self):
|
||||
"""list[dict] for list[PydanticModel] passes through unchanged."""
|
||||
|
||||
class ItemModel(BaseModel):
|
||||
name: str
|
||||
|
||||
class ContainerModel(BaseModel):
|
||||
items: list[ItemModel]
|
||||
|
||||
data: dict[str, Any] = {"items": [{"name": "a"}, {"name": "b"}]}
|
||||
coerce_inputs_to_schema(data, ContainerModel)
|
||||
# Dicts stay as dicts — Pydantic validation handles construction
|
||||
assert data["items"] == [{"name": "a"}, {"name": "b"}]
|
||||
assert isinstance(data["items"][0], dict)
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
-- Migrate Gemini 3 Pro Preview to Gemini 3.1 Pro Preview
|
||||
-- This updates all AgentNode blocks that use the deprecated Gemini 3 Pro Preview model
|
||||
-- Google is shutting down google/gemini-3-pro-preview on March 9, 2026
|
||||
|
||||
-- Update AgentNode constant inputs
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"google/gemini-3.1-pro-preview"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'google/gemini-3-pro-preview';
|
||||
|
||||
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
|
||||
UPDATE "AgentNodeExecutionInputOutput"
|
||||
SET "data" = JSONB_SET(
|
||||
"data"::jsonb,
|
||||
'{model}',
|
||||
'"google/gemini-3.1-pro-preview"'::jsonb
|
||||
)
|
||||
WHERE "agentPresetId" IS NOT NULL
|
||||
AND "data"::jsonb->>'model' = 'google/gemini-3-pro-preview';
|
||||
@@ -1,46 +0,0 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `search` on the `StoreListingVersion` table. All the data in the column will be lost.
|
||||
|
||||
*/-- CreateEnum
|
||||
CREATE TYPE "InvitedUserStatus" AS ENUM('INVITED',
|
||||
'CLAIMED',
|
||||
'REVOKED');
|
||||
-- CreateEnum
|
||||
CREATE TYPE "TallyComputationStatus" AS ENUM('PENDING',
|
||||
'RUNNING',
|
||||
'READY',
|
||||
'FAILED');
|
||||
-- CreateTable
|
||||
CREATE TABLE "InvitedUser"(
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"email" TEXT NOT NULL,
|
||||
"status" "InvitedUserStatus" NOT NULL DEFAULT 'INVITED',
|
||||
"authUserId" TEXT,
|
||||
"name" TEXT,
|
||||
"tallyUnderstanding" JSONB,
|
||||
"tallyStatus" "TallyComputationStatus" NOT NULL DEFAULT 'PENDING',
|
||||
"tallyComputedAt" TIMESTAMP(3),
|
||||
"tallyError" TEXT,
|
||||
CONSTRAINT "InvitedUser_pkey" PRIMARY KEY("id")
|
||||
);
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "InvitedUser_email_key"
|
||||
ON "InvitedUser"("email");
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "InvitedUser_authUserId_key"
|
||||
ON "InvitedUser"("authUserId");
|
||||
-- CreateIndex
|
||||
CREATE INDEX "InvitedUser_status_idx"
|
||||
ON "InvitedUser"("status");
|
||||
-- CreateIndex
|
||||
CREATE INDEX "InvitedUser_tallyStatus_idx"
|
||||
ON "InvitedUser"("tallyStatus");
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "InvitedUser" ADD CONSTRAINT "InvitedUser_authUserId_fkey" FOREIGN KEY("authUserId") REFERENCES "User"("id")
|
||||
ON DELETE
|
||||
SET NULL
|
||||
ON UPDATE CASCADE;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user